Fix thread-safety in GC, type cache, and instruction cache (#7355)

* Fix thread-safety in GC, type cache, and instruction cache

GC / refcount:
- Add safe_inc() check for strong()==0 in RefCount
- Add try_to_owned() to PyObject for atomic refcount acquire
- Replace strong_count()+to_owned() with try_to_owned() in GC
  collection and weakref callback paths to prevent TOCTOU races

Type cache:
- Add proper SeqLock (sequence counter) to TypeCacheEntry
- Readers spin-wait on odd sequence, validate after read
- Writers bracket updates with begin_write/end_write
- Use try_to_owned + pointer revalidation on read path
- Call modified() BEFORE attribute modification in set_attr

Instruction cache:
- Add pointer_cache (AtomicUsize array) to CodeUnits for
  single atomic pointer load/store (prevents torn reads)
- Add try_read_cached_descriptor with try_to_owned + pointer
  and version revalidation after increment
- Add write_cached_descriptor with version-bracketed writes

RLock:
- Fix release() to check is_owned_by_current_thread
- Add _release_save/_acquire_restore methods

* Fix RLock _acquire_restore tuple handling and unxfail threading test

* Align type cache seqlock writer protocol with CPython

* RLock: use single parking_lot level, track recursion manually

Instead of calling lock()/unlock() N times for recursion depth N,
keep parking_lot at 1 level and manage the count ourselves.
This makes acquire/release O(1) and matches CPython's
_PyRecursiveMutex approach (lock once + set level directly).

* Add try_to_owned_from_ptr to avoid &PyObject on stale ptrs

Use addr_of! to access ref_count directly from a raw pointer
without forming &PyObject first. Applied in type cache and
instruction cache hit paths where the pointer may be stale.

* Fix CI: spelling typo and xfail flaky test_thread_safety

- Fix "minimising" -> "minimizing" for cspell
- xfail test_thread_safety: dict iteration races with
  concurrent GC mutations in _finalizer_registry
This commit is contained in:
Jeong, YunWon
2026-03-05 20:33:14 +09:00
committed by GitHub
parent 86134e1ca1
commit 375b5472ed
10 changed files with 374 additions and 181 deletions

View File

@@ -4815,9 +4815,9 @@ class _TestFinalize(BaseTestCase):
result = [obj for obj in iter(conn.recv, 'STOP')]
self.assertEqual(result, ['a', 'b', 'd10', 'd03', 'd02', 'd01', 'e'])
# TODO: RUSTPYTHON - gc.get_threshold() and gc.set_threshold() not implemented
@unittest.expectedFailure
@support.requires_resource('cpu')
# TODO: RUSTPYTHON; dict iteration races with concurrent GC mutations
@unittest.expectedFailure
def test_thread_safety(self):
# bpo-24484: _run_finalizers() should be thread-safe
def cb():

View File

@@ -2141,10 +2141,6 @@ class CRLockTests(lock_tests.RLockTests):
CustomRLock(1, b=2)
self.assertEqual(warnings_log, [])
@unittest.expectedFailure # TODO: RUSTPYTHON
def test_release_save_unacquired(self):
return super().test_release_save_unacquired()
@unittest.skip('TODO: RUSTPYTHON; flaky test')
def test_different_thread(self):
return super().test_different_thread()

View File

@@ -123,7 +123,7 @@ impl RefCount {
pub fn safe_inc(&self) -> bool {
let mut old = State::from_raw(self.state.load(Ordering::Relaxed));
loop {
if old.destructed() {
if old.destructed() || old.strong() == 0 {
return false;
}
if (old.strong() as usize) >= STRONG {

View File

@@ -12,7 +12,7 @@ use core::{
cell::UnsafeCell,
hash, mem,
ops::Deref,
sync::atomic::{AtomicU8, AtomicU16, Ordering},
sync::atomic::{AtomicU8, AtomicU16, AtomicUsize, Ordering},
};
use itertools::Itertools;
use malachite_bigint::BigInt;
@@ -411,6 +411,10 @@ impl TryFrom<&[u8]> for CodeUnit {
pub struct CodeUnits {
/// The code units of the instruction stream, held in an `UnsafeCell`.
/// NOTE(review): presumably so opcodes and inline cache units can be
/// rewritten through `&self` — confirm against the `write_cache_*` methods.
units: UnsafeCell<Box<[CodeUnit]>>,
/// Adaptive-specialization counters. `From<Vec<CodeUnit>>` allocates one
/// zero-initialized `AtomicU16` per code unit, so this table is indexed by
/// the same code-unit index as `units`.
adaptive_counters: Box<[AtomicU16]>,
/// Pointer-sized cache entries for descriptor pointers.
/// Single atomic load/store prevents torn reads when multiple threads
/// specialize the same instruction concurrently.
/// Like `adaptive_counters`, it holds one zero-initialized slot per code
/// unit; a stored `0` means "no pointer cached".
pointer_cache: Box<[AtomicUsize]>,
}
// SAFETY: All cache operations use atomic read/write instructions.
@@ -432,9 +436,15 @@ impl Clone for CodeUnits {
.iter()
.map(|c| AtomicU16::new(c.load(Ordering::Relaxed)))
.collect();
let pointer_cache = self
.pointer_cache
.iter()
.map(|c| AtomicUsize::new(c.load(Ordering::Relaxed)))
.collect();
Self {
units: UnsafeCell::new(units),
adaptive_counters,
pointer_cache,
}
}
}
@@ -472,13 +482,19 @@ impl<const N: usize> From<[CodeUnit; N]> for CodeUnits {
impl From<Vec<CodeUnit>> for CodeUnits {
fn from(value: Vec<CodeUnit>) -> Self {
let units = value.into_boxed_slice();
let adaptive_counters = (0..units.len())
let len = units.len();
let adaptive_counters = (0..len)
.map(|_| AtomicU16::new(0))
.collect::<Vec<_>>()
.into_boxed_slice();
let pointer_cache = (0..len)
.map(|_| AtomicUsize::new(0))
.collect::<Vec<_>>()
.into_boxed_slice();
Self {
units: UnsafeCell::new(units),
adaptive_counters,
pointer_cache,
}
}
}
@@ -600,25 +616,29 @@ impl CodeUnits {
lo | (hi << 16)
}
/// Write a u64 value across four consecutive CACHE code units starting at `index`.
/// Store a pointer-sized value atomically in the pointer cache at `index`.
///
/// Uses a single `AtomicUsize` store to prevent torn writes when
/// multiple threads specialize the same instruction concurrently.
///
/// # Safety
/// Same requirements as `write_cache_u16`.
pub unsafe fn write_cache_u64(&self, index: usize, value: u64) {
unsafe {
self.write_cache_u32(index, value as u32);
self.write_cache_u32(index + 2, (value >> 32) as u32);
}
/// - `index` must be in bounds.
/// - `value` must be `0` or a valid `*const PyObject` encoded as `usize`.
/// - Callers must follow the cache invalidation/upgrade protocol:
/// invalidate the version guard before writing and publish the new
/// version after writing.
pub unsafe fn write_cache_ptr(&self, index: usize, value: usize) {
self.pointer_cache[index].store(value, Ordering::Relaxed);
}
/// Read a u64 value from four consecutive CACHE code units starting at `index`.
/// Load a pointer-sized value atomically from the pointer cache at `index`.
///
/// Uses a single `AtomicUsize` load to prevent torn reads.
///
/// # Panics
/// Panics if `index + 3` is out of bounds.
pub fn read_cache_u64(&self, index: usize) -> u64 {
let lo = self.read_cache_u32(index) as u64;
let hi = self.read_cache_u32(index + 2) as u64;
lo | (hi << 32)
/// Panics if `index` is out of bounds.
pub fn read_cache_ptr(&self, index: usize) -> usize {
self.pointer_cache[index].load(Ordering::Relaxed)
}
/// Read adaptive counter bits for instruction at `index`.

View File

@@ -64,10 +64,9 @@ static NEXT_TYPE_VERSION: AtomicU32 = AtomicU32::new(1);
// Method cache (type_cache / MCACHE): direct-mapped cache keyed by
// (tp_version_tag, interned_name_ptr).
//
// Uses a lock-free SeqLock pattern:
// - version acts as both cache key AND sequence counter
// - Read: load version (Acquire), read value ptr, re-check version
// - Write: set version=0 (invalidate), store value, store version (Release)
// Uses a lock-free SeqLock pattern for the read/write protocol:
// - Readers validate sequence/version/name before and after the value read.
// - Writers bracket updates with sequence odd/even transitions.
// No mutex needed on the hot path (cache hit).
const TYPE_CACHE_SIZE_EXP: u32 = 12;
@@ -75,6 +74,8 @@ const TYPE_CACHE_SIZE: usize = 1 << TYPE_CACHE_SIZE_EXP;
const TYPE_CACHE_MASK: usize = TYPE_CACHE_SIZE - 1;
struct TypeCacheEntry {
/// Sequence lock (odd = write in progress, even = quiescent).
sequence: AtomicU32,
/// tp_version_tag at cache time. 0 = empty/invalid.
version: AtomicU32,
/// Interned attribute name pointer (pointer equality check).
@@ -94,12 +95,60 @@ unsafe impl Sync for TypeCacheEntry {}
impl TypeCacheEntry {
/// Build an empty cache slot.
///
/// The sequence counter starts even (quiescent), the version starts at `0`
/// (the "empty/invalid" marker), and both pointers start out null.
fn new() -> Self {
    let sequence = AtomicU32::new(0);
    let version = AtomicU32::new(0);
    let name = AtomicPtr::new(core::ptr::null_mut());
    let value = AtomicPtr::new(core::ptr::null_mut());
    Self {
        sequence,
        version,
        name,
        value,
    }
}
/// Enter the write side of this entry's sequence lock.
///
/// Waits until the sequence is even (no write in progress) and then claims
/// the entry by advancing it to the next, odd value with a compare-exchange.
/// Every call must be followed by a matching `end_write`, which makes the
/// sequence even again.
#[inline]
fn begin_write(&self) {
let mut seq = self.sequence.load(Ordering::Acquire);
loop {
// Odd sequence: another writer currently holds the entry, so spin
// until it has been released.
while (seq & 1) != 0 {
core::hint::spin_loop();
seq = self.sequence.load(Ordering::Acquire);
}
// Try to move even -> odd. The weak compare-exchange may fail
// spuriously; the enclosing loop absorbs that by retrying.
match self.sequence.compare_exchange_weak(
seq,
seq.wrapping_add(1),
Ordering::AcqRel,
Ordering::Acquire,
) {
Ok(_) => {
// Release fence: keeps the odd-sequence store above ordered
// before the field stores the caller performs next.
// NOTE(review): presumably mirrors CPython's seqlock writer — confirm.
core::sync::atomic::fence(Ordering::Release);
break;
}
Err(observed) => {
// Lost the race (or failed spuriously): retry from the value
// that was actually observed.
core::hint::spin_loop();
seq = observed;
}
}
}
}
/// Leave the write side of the sequence lock.
///
/// Adds one to the sequence with `Release` ordering. When paired with
/// `begin_write` (which left the sequence odd) this makes it even again, and
/// the new value differs from the one any overlapping reader recorded in
/// `begin_read`, so that reader's `end_read` check fails.
#[inline]
fn end_write(&self) {
self.sequence.fetch_add(1, Ordering::Release);
}
/// Start an optimistic read of this entry.
///
/// Spins while the sequence is odd (a write is in progress) and returns the
/// even sequence value that was observed. The caller passes that value to
/// `end_read` after reading the entry's fields to learn whether the read
/// overlapped a write.
#[inline]
fn begin_read(&self) -> u32 {
let mut sequence = self.sequence.load(Ordering::Acquire);
// Odd means a writer is between `begin_write` and `end_write`.
while (sequence & 1) != 0 {
core::hint::spin_loop();
sequence = self.sequence.load(Ordering::Acquire);
}
sequence
}
/// Finish an optimistic read started with `begin_read`.
///
/// Returns `true` when the sequence still equals `previous`, i.e. the
/// sequence was not advanced by a writer in between; on `false` the caller
/// must discard whatever it read.
#[inline]
fn end_read(&self, previous: u32) -> bool {
// Acquire fence: the field loads performed since `begin_read` must not be
// reordered after the sequence re-check below.
core::sync::atomic::fence(Ordering::Acquire);
self.sequence.load(Ordering::Relaxed) == previous
}
/// Take the value out of this entry, returning the owned PyObjectRef.
/// Caller must ensure no concurrent reads can observe this entry
/// (version should be set to 0 first).
@@ -137,10 +186,14 @@ fn type_cache_clear_version(version: u32) {
let mut to_drop = Vec::new();
for entry in TYPE_CACHE.iter() {
if entry.version.load(Ordering::Relaxed) == version {
entry.version.store(0, Ordering::Release);
if let Some(v) = entry.take_value() {
to_drop.push(v);
entry.begin_write();
if entry.version.load(Ordering::Relaxed) == version {
entry.version.store(0, Ordering::Release);
if let Some(v) = entry.take_value() {
to_drop.push(v);
}
}
entry.end_write();
}
}
drop(to_drop);
@@ -158,10 +211,12 @@ pub fn type_cache_clear() {
// Invalidate all entries and collect values.
let mut to_drop = Vec::new();
for entry in TYPE_CACHE.iter() {
entry.begin_write();
entry.version.store(0, Ordering::Release);
if let Some(v) = entry.take_value() {
to_drop.push(v);
}
entry.end_write();
}
drop(to_drop);
TYPE_CACHE_CLEARING.store(false, Ordering::Release);
@@ -701,8 +756,11 @@ impl PyType {
}
pub fn set_attr(&self, attr_name: &'static PyStrInterned, value: PyObjectRef) {
self.attributes.write().insert(attr_name, value);
// Invalidate caches BEFORE modifying attributes so that cached
// descriptor pointers are still alive when type_cache_clear_version
// drops the cache's strong references.
self.modified();
self.attributes.write().insert(attr_name, value);
}
/// Internal get_attr implementation for fast lookup on a class.
@@ -718,41 +776,43 @@ impl PyType {
/// find_name_in_mro with method cache (MCACHE).
/// Looks in tp_dict of types in MRO, bypasses descriptors.
///
/// Uses a lock-free SeqLock pattern keyed by version:
/// Read: load version → check name → load value → clone → re-check version
/// Write: version=0 → swap value → set name → version=assigned
/// Uses a lock-free SeqLock-style pattern:
/// Read: load sequence/version/name → load value + try_to_owned
/// validate value pointer + sequence
/// Write: sequence(begin) → version=0 → swap value/name → version=assigned → sequence(end)
fn find_name_in_mro(&self, name: &'static PyStrInterned) -> Option<PyObjectRef> {
let version = self.tp_version_tag.load(Ordering::Acquire);
if version != 0 {
let idx = type_cache_hash(version, name);
let entry = &TYPE_CACHE[idx];
let v1 = entry.version.load(Ordering::Acquire);
if v1 == version
&& core::ptr::eq(
entry.name.load(Ordering::Relaxed),
name as *const _ as *mut _,
)
{
let name_ptr = name as *const _ as *mut _;
loop {
let seq1 = entry.begin_read();
let v1 = entry.version.load(Ordering::Acquire);
let type_version = self.tp_version_tag.load(Ordering::Acquire);
if v1 != type_version
|| !core::ptr::eq(entry.name.load(Ordering::Relaxed), name_ptr)
{
break;
}
let ptr = entry.value.load(Ordering::Acquire);
if !ptr.is_null() {
// SAFETY: The value pointer was stored via PyObjectRef::into_raw
// and is valid as long as the version hasn't changed. We create
// a temporary reference (ManuallyDrop prevents decrement), clone
// it to get our own strong reference, then re-check the version
// to confirm the entry wasn't invalidated during our read.
let cloned = unsafe {
let tmp = core::mem::ManuallyDrop::new(PyObjectRef::from_raw(
NonNull::new_unchecked(ptr),
));
(*tmp).clone()
};
// SeqLock validation: if version changed, discard our clone
let v2 = entry.version.load(Ordering::Acquire);
if v2 == v1 {
if ptr.is_null() {
if entry.end_read(seq1) {
break;
}
continue;
}
// _Py_TryIncrefCompare-style validation:
// safe_inc via raw pointer, then ensure source is unchanged.
if let Some(cloned) = unsafe { PyObject::try_to_owned_from_ptr(ptr) } {
let same_ptr = core::ptr::eq(entry.value.load(Ordering::Relaxed), ptr);
if same_ptr && entry.end_read(seq1) {
return Some(cloned);
}
drop(cloned);
continue;
}
break;
}
}
@@ -777,16 +837,17 @@ impl PyType {
{
let idx = type_cache_hash(assigned, name);
let entry = &TYPE_CACHE[idx];
let name_ptr = name as *const _ as *mut _;
entry.begin_write();
// Invalidate first to prevent readers from seeing partial state
entry.version.store(0, Ordering::Release);
// Swap in new value (refcount held by cache)
let new_ptr = found.clone().into_raw().as_ptr();
let old_ptr = entry.value.swap(new_ptr, Ordering::Relaxed);
entry
.name
.store(name as *const _ as *mut _, Ordering::Relaxed);
entry.name.store(name_ptr, Ordering::Relaxed);
// Activate entry — Release ensures value/name writes are visible
entry.version.store(assigned, Ordering::Release);
entry.end_write();
// Drop previous occupant (its version was already invalidated)
if !old_ptr.is_null() {
unsafe {
@@ -832,20 +893,24 @@ impl PyType {
if version != 0 {
let idx = type_cache_hash(version, name);
let entry = &TYPE_CACHE[idx];
let v1 = entry.version.load(Ordering::Acquire);
if v1 == version
&& core::ptr::eq(
entry.name.load(Ordering::Relaxed),
name as *const _ as *mut _,
)
{
let name_ptr = name as *const _ as *mut _;
loop {
let seq1 = entry.begin_read();
let v1 = entry.version.load(Ordering::Acquire);
let type_version = self.tp_version_tag.load(Ordering::Acquire);
if v1 != type_version
|| !core::ptr::eq(entry.name.load(Ordering::Relaxed), name_ptr)
{
break;
}
let ptr = entry.value.load(Ordering::Acquire);
if !ptr.is_null() {
let v2 = entry.version.load(Ordering::Acquire);
if v2 == v1 {
if entry.end_read(seq1) {
if !ptr.is_null() {
return true;
}
break;
}
continue;
}
}
@@ -1498,8 +1563,8 @@ impl PyType {
PySetterValue::Assign(ref val) => {
let key = identifier!(vm, __type_params__);
self.check_set_special_type_attr(key, vm)?;
self.attributes.write().insert(key, val.clone().into());
self.modified();
self.attributes.write().insert(key, val.clone().into());
}
PySetterValue::Delete => {
// For delete, we still need to check if the type is immutable
@@ -1510,8 +1575,8 @@ impl PyType {
)));
}
let key = identifier!(vm, __type_params__);
self.attributes.write().shift_remove(&key);
self.modified();
self.attributes.write().shift_remove(&key);
}
}
Ok(())

View File

@@ -3257,10 +3257,10 @@ impl ExecutingFrame<'_> {
let owner = self.top_value();
let type_version = self.code.instructions.read_cache_u32(cache_base + 1);
if type_version != 0 && owner.class().tp_version_tag.load(Acquire) == type_version {
// Cache hit: load the cached method descriptor
let descr_ptr = self.code.instructions.read_cache_u64(cache_base + 5);
let func = unsafe { &*(descr_ptr as *const PyObject) }.to_owned();
if type_version != 0
&& owner.class().tp_version_tag.load(Acquire) == type_version
&& let Some(func) = self.try_read_cached_descriptor(cache_base, type_version)
{
let owner = self.pop_value();
self.push_value(func);
self.push_value(owner);
@@ -3287,9 +3287,8 @@ impl ExecutingFrame<'_> {
if type_version != 0
&& owner.class().tp_version_tag.load(Acquire) == type_version
&& owner.dict().is_none()
&& let Some(func) = self.try_read_cached_descriptor(cache_base, type_version)
{
let descr_ptr = self.code.instructions.read_cache_u64(cache_base + 5);
let func = unsafe { &*(descr_ptr as *const PyObject) }.to_owned();
let owner = self.pop_value();
self.push_value(func);
self.push_value(owner);
@@ -3345,10 +3344,10 @@ impl ExecutingFrame<'_> {
false
};
if !shadowed {
// Cache hit: load the cached method descriptor
let descr_ptr = self.code.instructions.read_cache_u64(cache_base + 5);
let func = unsafe { &*(descr_ptr as *const PyObject) }.to_owned();
if !shadowed
&& let Some(func) =
self.try_read_cached_descriptor(cache_base, type_version)
{
let owner = self.pop_value();
self.push_value(func);
self.push_value(owner);
@@ -3475,10 +3474,10 @@ impl ExecutingFrame<'_> {
let owner = self.top_value();
let type_version = self.code.instructions.read_cache_u32(cache_base + 1);
if type_version != 0 && owner.class().tp_version_tag.load(Acquire) == type_version {
// Load cached class attribute directly (no dict, no data descriptor)
let descr_ptr = self.code.instructions.read_cache_u64(cache_base + 5);
let attr = unsafe { &*(descr_ptr as *const PyObject) }.to_owned();
if type_version != 0
&& owner.class().tp_version_tag.load(Acquire) == type_version
&& let Some(attr) = self.try_read_cached_descriptor(cache_base, type_version)
{
self.pop_value();
if oparg.is_method() {
self.push_value(attr);
@@ -3528,8 +3527,17 @@ impl ExecutingFrame<'_> {
return Ok(None);
}
// Not in instance dict — use cached class attr
let descr_ptr = self.code.instructions.read_cache_u64(cache_base + 5);
let attr = unsafe { &*(descr_ptr as *const PyObject) }.to_owned();
let Some(attr) = self.try_read_cached_descriptor(cache_base, type_version)
else {
self.deoptimize_at(
Instruction::LoadAttr {
namei: Arg::marker(),
},
instr_idx,
cache_base,
);
return self.load_attr_slow(vm, oparg);
};
self.pop_value();
if oparg.is_method() {
self.push_value(attr);
@@ -3566,9 +3574,8 @@ impl ExecutingFrame<'_> {
if type_version != 0
&& let Some(owner_type) = owner.downcast_ref::<PyType>()
&& owner_type.tp_version_tag.load(Acquire) == type_version
&& let Some(attr) = self.try_read_cached_descriptor(cache_base, type_version)
{
let descr_ptr = self.code.instructions.read_cache_u64(cache_base + 5);
let attr = unsafe { &*(descr_ptr as *const PyObject) }.to_owned();
self.pop_value();
if oparg.is_method() {
self.push_value(attr);
@@ -3608,9 +3615,8 @@ impl ExecutingFrame<'_> {
&& let Some(owner_type) = owner.downcast_ref::<PyType>()
&& owner_type.tp_version_tag.load(Acquire) == type_version
&& owner.class().tp_version_tag.load(Acquire) == metaclass_version
&& let Some(attr) = self.try_read_cached_descriptor(cache_base, type_version)
{
let descr_ptr = self.code.instructions.read_cache_u64(cache_base + 5);
let attr = unsafe { &*(descr_ptr as *const PyObject) }.to_owned();
self.pop_value();
if oparg.is_method() {
self.push_value(attr);
@@ -3683,19 +3689,16 @@ impl ExecutingFrame<'_> {
let owner = self.top_value();
let type_version = self.code.instructions.read_cache_u32(cache_base + 1);
if type_version != 0 && owner.class().tp_version_tag.load(Acquire) == type_version {
let descr_ptr = self.code.instructions.read_cache_u64(cache_base + 5);
if descr_ptr != 0 {
let descr = unsafe { &*(descr_ptr as *const PyObject) };
if let Some(prop) = descr.downcast_ref::<PyProperty>() {
let owner = self.pop_value();
if let Some(getter) = prop.get_fget() {
let result = getter.call((owner,), vm)?;
self.push_value(result);
return Ok(None);
}
}
}
if type_version != 0
&& owner.class().tp_version_tag.load(Acquire) == type_version
&& let Some(descr) = self.try_read_cached_descriptor(cache_base, type_version)
&& let Some(prop) = descr.downcast_ref::<PyProperty>()
&& let Some(getter) = prop.get_fget()
{
let owner = self.pop_value();
let result = getter.call((owner,), vm)?;
self.push_value(result);
return Ok(None);
}
unsafe {
self.code.instructions.replace_op(
@@ -6883,6 +6886,70 @@ impl ExecutingFrame<'_> {
Ok(None)
}
/// Try to take a strong reference to the descriptor cached for the
/// instruction whose inline cache starts at `cache_base`.
///
/// Returns `None` when the pointer slot is empty, when a strong reference
/// could not be acquired, or when the cached type version (at
/// `cache_base + 1`) or the cached pointer (at `cache_base + 5`) no longer
/// match what was seen before the reference was taken.
#[inline]
fn try_read_cached_descriptor(
    &self,
    cache_base: usize,
    expected_type_version: u32,
) -> Option<PyObjectRef> {
    let cache = &self.code.instructions;

    let raw = cache.read_cache_ptr(cache_base + 5);
    if raw == 0 {
        return None;
    }

    // SAFETY: relies on the contract of `try_to_owned_from_ptr`; the slot
    // only ever holds `0` or a descriptor pointer published by the cache
    // writer. NOTE(review): the pointer may be stale, which is why the slot
    // is re-validated below before the reference is handed out.
    let owned = unsafe { PyObject::try_to_owned_from_ptr(raw as *mut PyObject) }?;

    // Re-read both guards *after* the increment: only if neither changed do
    // we know the reference belongs to the entry we meant to read.
    let still_current = cache.read_cache_u32(cache_base + 1) == expected_type_version
        && cache.read_cache_ptr(cache_base + 5) == raw;

    // On mismatch `then_some` drops `owned`, releasing the reference again.
    still_current.then_some(owned)
}
/// Publish a `(type_version, descr_ptr)` pair into the inline cache that
/// starts at `cache_base`.
///
/// The version guard at `cache_base + 1` is zeroed first, then the pointer
/// at `cache_base + 5` is stored, and only then is the real version written,
/// so the version guard is never valid while the pointer is being replaced.
///
/// # Safety
/// The caller must uphold the requirements of `write_cache_u32` and
/// `write_cache_ptr` for `cache_base + 1` and `cache_base + 5`.
#[inline]
unsafe fn write_cached_descriptor(
    &self,
    cache_base: usize,
    type_version: u32,
    descr_ptr: usize,
) {
    let cache = &self.code.instructions;
    // SAFETY: forwarded to the caller through this function's own contract.
    unsafe {
        // 1. Invalidate the guard.
        cache.write_cache_u32(cache_base + 1, 0);
        // 2. Replace the payload.
        cache.write_cache_ptr(cache_base + 5, descr_ptr);
        // 3. Re-arm the guard with the new version.
        cache.write_cache_u32(cache_base + 1, type_version);
    }
}
/// Publish a descriptor cache entry that is additionally guarded by a
/// metaclass version.
///
/// Follows the same order as `write_cached_descriptor` — zero the type
/// version at `cache_base + 1`, write the payload, then store the real type
/// version — with the metaclass version at `cache_base + 3` written as part
/// of the payload.
///
/// # Safety
/// The caller must uphold the requirements of `write_cache_u32` and
/// `write_cache_ptr` for `cache_base + 1`, `cache_base + 3` and
/// `cache_base + 5`.
#[inline]
unsafe fn write_cached_descriptor_with_metaclass(
    &self,
    cache_base: usize,
    type_version: u32,
    metaclass_version: u32,
    descr_ptr: usize,
) {
    let cache = &self.code.instructions;
    // SAFETY: forwarded to the caller through this function's own contract.
    unsafe {
        // 1. Invalidate the type-version guard.
        cache.write_cache_u32(cache_base + 1, 0);
        // 2. Write the payload: metaclass guard, then descriptor pointer.
        cache.write_cache_u32(cache_base + 3, metaclass_version);
        cache.write_cache_ptr(cache_base + 5, descr_ptr);
        // 3. Re-arm the type-version guard.
        cache.write_cache_u32(cache_base + 1, type_version);
    }
}
fn load_attr(&mut self, vm: &VirtualMachine, oparg: LoadAttr) -> FrameResult {
self.adaptive(|s, ii, cb| s.specialize_load_attr(vm, oparg, ii, cb));
self.load_attr_slow(vm, oparg)
@@ -6973,14 +7040,9 @@ impl ExecutingFrame<'_> {
.flags
.has_feature(PyTypeFlags::METHOD_DESCRIPTOR)
{
let descr_ptr = &**descr as *const PyObject as u64;
let descr_ptr = &**descr as *const PyObject as usize;
unsafe {
self.code
.instructions
.write_cache_u32(cache_base + 1, type_version);
self.code
.instructions
.write_cache_u64(cache_base + 5, descr_ptr);
self.write_cached_descriptor(cache_base, type_version, descr_ptr);
}
let new_op = if !class_has_dict {
@@ -7032,14 +7094,9 @@ impl ExecutingFrame<'_> {
&& descr.downcast_ref::<PyProperty>().is_some()
{
// Property descriptor — cache the property object pointer
let descr_ptr = &**descr as *const PyObject as u64;
let descr_ptr = &**descr as *const PyObject as usize;
unsafe {
self.code
.instructions
.write_cache_u32(cache_base + 1, type_version);
self.code
.instructions
.write_cache_u64(cache_base + 5, descr_ptr);
self.write_cached_descriptor(cache_base, type_version, descr_ptr);
}
self.specialize_at(instr_idx, cache_base, Instruction::LoadAttrProperty);
} else {
@@ -7065,14 +7122,9 @@ impl ExecutingFrame<'_> {
} else if class_has_dict {
if let Some(ref descr) = cls_attr {
// Plain class attr + class supports dict — check dict first, fallback
let descr_ptr = &**descr as *const PyObject as u64;
let descr_ptr = &**descr as *const PyObject as usize;
unsafe {
self.code
.instructions
.write_cache_u32(cache_base + 1, type_version);
self.code
.instructions
.write_cache_u64(cache_base + 5, descr_ptr);
self.write_cached_descriptor(cache_base, type_version, descr_ptr);
}
self.specialize_at(
instr_idx,
@@ -7119,14 +7171,9 @@ impl ExecutingFrame<'_> {
}
} else if let Some(ref descr) = cls_attr {
// No dict support, plain class attr — cache directly
let descr_ptr = &**descr as *const PyObject as u64;
let descr_ptr = &**descr as *const PyObject as usize;
unsafe {
self.code
.instructions
.write_cache_u32(cache_base + 1, type_version);
self.code
.instructions
.write_cache_u64(cache_base + 5, descr_ptr);
self.write_cached_descriptor(cache_base, type_version, descr_ptr);
}
self.specialize_at(
instr_idx,
@@ -7220,22 +7267,23 @@ impl ExecutingFrame<'_> {
let has_descr_get = descr_class.slots.descr_get.load().is_some();
if !has_descr_get {
// METHOD or NON_DESCRIPTOR — can cache directly
let descr_ptr = &**descr as *const PyObject as u64;
let descr_ptr = &**descr as *const PyObject as usize;
let new_op = if metaclass_version == 0 {
Instruction::LoadAttrClass
} else {
Instruction::LoadAttrClassWithMetaclassCheck
};
unsafe {
self.code
.instructions
.write_cache_u32(cache_base + 1, type_version);
self.code
.instructions
.write_cache_u32(cache_base + 3, metaclass_version);
self.code
.instructions
.write_cache_u64(cache_base + 5, descr_ptr);
if metaclass_version == 0 {
self.write_cached_descriptor(cache_base, type_version, descr_ptr);
} else {
self.write_cached_descriptor_with_metaclass(
cache_base,
type_version,
metaclass_version,
descr_ptr,
);
}
}
self.specialize_at(instr_idx, cache_base, new_op);
return;

View File

@@ -299,13 +299,7 @@ impl GcState {
fn collect_from_list(
list: &LinkedList<GcLink, PyObject>,
) -> impl Iterator<Item = PyObjectRef> + '_ {
list.iter().filter_map(|obj| {
if obj.strong_count() > 0 {
Some(obj.to_owned())
} else {
None
}
})
list.iter().filter_map(|obj| obj.try_to_owned())
}
match generation {
@@ -468,15 +462,16 @@ impl GcState {
// After dropping gen_locks, other threads can untrack+free objects,
// making the raw pointers in `reachable`/`unreachable` dangling.
// Strong refs keep objects alive for later phases.
//
// Use try_to_owned() (CAS-based) instead of strong_count()+to_owned()
// to prevent a TOCTOU race: another thread can dec() the count to 0
// between the check and the increment, causing a use-after-free when
// the destroying thread eventually frees the memory.
let survivor_refs: Vec<PyObjectRef> = reachable
.iter()
.filter_map(|ptr| {
let obj = unsafe { ptr.0.as_ref() };
if obj.strong_count() > 0 {
Some(obj.to_owned())
} else {
None
}
obj.try_to_owned()
})
.collect();
@@ -484,11 +479,7 @@ impl GcState {
.iter()
.filter_map(|ptr| {
let obj = unsafe { ptr.0.as_ref() };
if obj.strong_count() > 0 {
Some(obj.to_owned())
} else {
None
}
obj.try_to_owned()
})
.collect();

View File

@@ -576,11 +576,12 @@ impl WeakRefList {
ptrs.as_mut().set_next(None);
}
// Collect callback if present and weakref is still alive
if wr.0.ref_count.get() > 0 {
// Collect callback only if we can still acquire a strong ref.
if wr.0.ref_count.safe_inc() {
let wr_ref = unsafe { PyRef::from_raw(wr as *const Py<PyWeak>) };
let cb = unsafe { wr.0.payload.callback.get().replace(None) };
if let Some(cb) = cb {
callbacks.push((wr.to_owned(), cb));
callbacks.push((wr_ref, cb));
}
}
@@ -626,11 +627,12 @@ impl WeakRefList {
ptrs.as_mut().set_next(None);
}
// Collect callback without invoking
if wr.0.ref_count.get() > 0 {
// Collect callback without invoking only if we can keep weakref alive.
if wr.0.ref_count.safe_inc() {
let wr_ref = unsafe { PyRef::from_raw(wr as *const Py<PyWeak>) };
let cb = unsafe { wr.0.payload.callback.get().replace(None) };
if let Some(cb) = cb {
callbacks.push((wr.to_owned(), cb));
callbacks.push((wr_ref, cb));
}
}
@@ -660,8 +662,8 @@ impl WeakRefList {
let mut current = NonNull::new(self.head.load(Ordering::Relaxed));
while let Some(node) = current {
let wr = unsafe { node.as_ref() };
if wr.0.ref_count.get() > 0 {
v.push(wr.to_owned());
if wr.0.ref_count.safe_inc() {
v.push(unsafe { PyRef::from_raw(wr as *const Py<PyWeak>) });
}
current = unsafe { WeakLink::pointers(node).as_ref().get_next() };
}
@@ -960,6 +962,46 @@ impl ToOwned for PyObject {
}
}
impl PyObject {
    /// Try to obtain a new strong reference to this object.
    ///
    /// The liveness check and the increment happen in one step inside
    /// `RefCount::safe_inc`, which closes the TOCTOU window of testing
    /// `strong_count()` and then calling `to_owned()`.
    /// Returns `None` if the strong count is already 0 (object being
    /// destroyed).
    #[inline]
    pub fn try_to_owned(&self) -> Option<PyObjectRef> {
        // `then` is lazy: the `PyObjectRef` (whose drop would decrement the
        // count) is only built once the increment has actually succeeded.
        self.0.ref_count.safe_inc().then(|| PyObjectRef {
            ptr: NonNull::from(self),
        })
    }

    /// Raw-pointer variant of [`try_to_owned`](Self::try_to_owned).
    ///
    /// Reaches `ref_count` through `addr_of!` so that no `&PyObject` is
    /// formed, minimizing the borrow scope when the pointer may be stale
    /// (e.g. cache-hit paths protected by version guards).
    ///
    /// # Safety
    /// `ptr` must point to a live (not yet deallocated) `PyObject`, or to
    /// memory whose `ref_count` field is still atomically readable
    /// (same guarantee as `_Py_TryIncRefShared`).
    #[inline]
    pub unsafe fn try_to_owned_from_ptr(ptr: *mut Self) -> Option<PyObjectRef> {
        let inner = ptr.cast::<PyInner<Erased>>();
        // SAFETY: the caller guarantees the `ref_count` field behind `ptr`
        // is readable; only that field is borrowed.
        let ref_count = unsafe { &*core::ptr::addr_of!((*inner).ref_count) };
        if !ref_count.safe_inc() {
            return None;
        }
        Some(PyObjectRef {
            // SAFETY: `ptr` was just dereferenced to reach `ref_count`, so
            // the caller's contract already rules out a null pointer.
            ptr: unsafe { NonNull::new_unchecked(ptr) },
        })
    }
}
impl PyObjectRef {
#[inline(always)]
pub const fn into_raw(self) -> NonNull<PyObject> {

View File

@@ -221,26 +221,32 @@ pub(crate) mod _thread {
#[pymethod(name = "acquire_lock")]
#[pymethod(name = "__enter__")]
fn acquire(&self, args: AcquireArgs, vm: &VirtualMachine) -> PyResult<bool> {
let result = acquire_lock_impl!(&self.mu, args, vm)?;
if result {
if self.mu.is_owned_by_current_thread() {
// Re-entrant acquisition: just increment our count.
// parking_lot stays at 1 level; we track recursion ourselves.
self.count
.fetch_add(1, core::sync::atomic::Ordering::Relaxed);
return Ok(true);
}
let result = acquire_lock_impl!(&self.mu, args, vm)?;
if result {
self.count.store(1, core::sync::atomic::Ordering::Relaxed);
}
Ok(result)
}
#[pymethod]
#[pymethod(name = "release_lock")]
fn release(&self, vm: &VirtualMachine) -> PyResult<()> {
if !self.mu.is_locked() {
return Err(vm.new_runtime_error("release unlocked lock"));
if !self.mu.is_owned_by_current_thread() {
return Err(vm.new_runtime_error("cannot release un-acquired lock"));
}
debug_assert!(
self.count.load(core::sync::atomic::Ordering::Relaxed) > 0,
"RLock count underflow"
);
self.count
let prev = self
.count
.fetch_sub(1, core::sync::atomic::Ordering::Relaxed);
unsafe { self.mu.unlock() };
debug_assert!(prev > 0, "RLock count underflow");
if prev == 1 {
unsafe { self.mu.unlock() };
}
Ok(())
}
@@ -278,6 +284,35 @@ pub(crate) mod _thread {
}
}
/// Fully release the lock regardless of its recursion depth and return
/// `(depth, thread_id)` so the state can be restored later with
/// `_acquire_restore`.
///
/// Raises `RuntimeError` if the calling thread does not own the lock.
#[pymethod]
fn _release_save(&self, vm: &VirtualMachine) -> PyResult<(usize, u64)> {
    use core::sync::atomic::Ordering::Relaxed;

    if !self.mu.is_owned_by_current_thread() {
        return Err(vm.new_runtime_error("cannot release un-acquired lock"));
    }

    // Take the whole recursion depth in one go, leaving the counter at 0.
    let saved_depth = self.count.swap(0, Relaxed);
    debug_assert!(saved_depth > 0, "RLock count underflow");

    // SAFETY: ownership by the current thread was verified above.
    unsafe { self.mu.unlock() };

    let owner = current_thread_id();
    Ok((saved_depth, owner))
}
/// Re-acquire the lock and restore a recursion depth previously returned
/// by `_release_save`.
///
/// `state` must be a 2-item tuple `(count, owner)`; anything else raises
/// `TypeError`. A `count` of 0 is a no-op. The owner value is converted for
/// validation only and is otherwise unused.
#[pymethod]
fn _acquire_restore(&self, state: PyTupleRef, vm: &VirtualMachine) -> PyResult<()> {
    use core::sync::atomic::Ordering::Relaxed;

    let (count_obj, owner_obj) = match state.as_slice() {
        [count_obj, owner_obj] => (count_obj, owner_obj),
        _ => {
            return Err(
                vm.new_type_error("_acquire_restore() argument 1 must be a 2-item tuple")
            );
        }
    };
    let count: usize = count_obj.clone().try_into_value(vm)?;
    let _owner: u64 = owner_obj.clone().try_into_value(vm)?;

    if count != 0 {
        // Lock once, then set the recursion depth directly instead of
        // locking `count` times.
        self.mu.lock();
        self.count.store(count, Relaxed);
    }
    Ok(())
}
#[pymethod]
fn __exit__(&self, _args: FuncArgs, vm: &VirtualMachine) -> PyResult<()> {
self.release(vm)

View File

@@ -22,7 +22,7 @@ use crate::{
self, PyBaseExceptionRef, PyDict, PyDictRef, PyInt, PyList, PyModule, PyStr, PyStrInterned,
PyStrRef, PyTypeRef, PyUtf8Str, PyUtf8StrInterned, PyWeak,
code::PyCode,
dict::{PyDictItems, PyDictKeys, PyDictValues},
dict::{PyDictItems, PyDictValues},
pystr::AsPyStr,
tuple::PyTuple,
},
@@ -1319,10 +1319,6 @@ impl VirtualMachine {
} else if cls.is(self.ctx.types.list_type) {
list_borrow = value.downcast_ref::<PyList>().unwrap().borrow_vec();
&list_borrow
} else if cls.is(self.ctx.types.dict_keys_type) {
// Atomic snapshot of dict keys - prevents race condition during iteration
let keys = value.downcast_ref::<PyDictKeys>().unwrap().dict.keys_vec();
return keys.into_iter().map(func).collect();
} else if cls.is(self.ctx.types.dict_values_type) {
// Atomic snapshot of dict values - prevents race condition during iteration
let values = value