diff --git a/vm/src/dictdatatype.rs b/vm/src/dictdatatype.rs index 2fdd6925d..72f95f67e 100644 --- a/vm/src/dictdatatype.rs +++ b/vm/src/dictdatatype.rs @@ -37,34 +37,24 @@ impl fmt::Debug for Dict { } } -#[derive(Debug, Copy, Clone)] -enum IndexEntry { - Dummy, - Free, - Index(usize), -} +#[derive(Debug, Copy, Clone, PartialEq, Eq)] +#[repr(transparent)] +struct IndexEntry(i64); impl IndexEntry { - const FREE: i64 = -1; - const DUMMY: i64 = -2; -} + const FREE: Self = Self(-1); + const DUMMY: Self = Self(-2); -impl From for IndexEntry { - fn from(idx: i64) -> Self { - match idx { - IndexEntry::FREE => IndexEntry::Free, - IndexEntry::DUMMY => IndexEntry::Dummy, - x => IndexEntry::Index(x as usize), - } + unsafe fn from_index_unchecked(idx: usize) -> Self { + debug_assert!((idx as isize) >= 0); + Self(idx as i64) } -} -impl From for i64 { - fn from(idx: IndexEntry) -> Self { - match idx { - IndexEntry::Free => IndexEntry::FREE, - IndexEntry::Dummy => IndexEntry::DUMMY, - IndexEntry::Index(i) => i as i64, + fn index(self) -> Option { + if self.0 >= 0 { + Some(self.0 as usize) + } else { + None } } } @@ -73,7 +63,7 @@ impl From for i64 { struct DictInner { used: usize, filled: usize, - indices: Vec, + indices: Vec, entries: Vec>>, } @@ -175,7 +165,10 @@ impl DictInner { let index_index = idxs.next(); let idx = &mut self.indices[index_index]; if *idx == IndexEntry::FREE { - *idx = entry_idx as i64; + *idx = unsafe { + // entry_idx never grow up to usize::MAX + IndexEntry::from_index_unchecked(entry_idx) + }; entry.index = index_index; break; } @@ -203,9 +196,9 @@ impl DictInner { }; let entry_index = self.entries.len(); self.entries.push(Some(entry)); - self.indices[index] = entry_index as i64; + self.indices[index] = unsafe { IndexEntry::from_index_unchecked(entry_index) }; self.used += 1; - if let IndexEntry::Free = index_entry { + if let IndexEntry::FREE = index_entry { self.filled += 1; if let Some(new_size) = self.should_resize() { self.resize(new_size) @@ -252,7 +245,7 @@ impl Dict { let _removed = loop { let (entry_index, index_index) = self.lookup(vm, key, hash, None)?; let mut inner = self.write(); - if let IndexEntry::Index(index) = entry_index { + if let Some(index) = entry_index.index() { // Update existing key if let Some(entry) = inner.entries.get_mut(index) { let entry = entry @@ -279,7 +272,7 @@ impl Dict { pub fn contains(&self, vm: &VirtualMachine, key: &K) -> PyResult { let (entry, _) = self.lookup(vm, key, key.key_hash(vm)?, None)?; - Ok(matches!(entry, IndexEntry::Index(_))) + Ok(entry.index().is_some()) } /// Retrieve a key @@ -297,7 +290,7 @@ impl Dict { ) -> PyResult> { let ret = loop { let (entry, index_index) = self.lookup(vm, key, hash, None)?; - if let IndexEntry::Index(index) = entry { + if let Some(index) = entry.index() { let inner = self.read(); if let Some(entry) = inner.entries.get(index) { let entry = extract_dict_entry(entry); @@ -385,7 +378,7 @@ impl Dict { let _removed = loop { let lookup = self.lookup(vm, key, hash, None)?; let (entry, index_index) = lookup; - if let IndexEntry::Index(_) = entry { + if entry.index().is_some() { match self.pop_inner(lookup) { ControlFlow::Break(Some(entry)) => break Some(entry), _ => continue, @@ -407,8 +400,8 @@ impl Dict { let hash = key.key_hash(vm)?; let res = loop { let lookup = self.lookup(vm, key, hash, None)?; - let (entry, index_index) = lookup; - if let IndexEntry::Index(index) = entry { + let (index_entry, index_index) = lookup; + if let Some(index) = index_entry.index() { let inner = self.read(); if let Some(entry) = inner.entries.get(index) { let entry = extract_dict_entry(entry); @@ -424,7 +417,13 @@ impl Dict { } else { let value = default(); let mut inner = self.write(); - inner.unchecked_push(index_index, hash, key.to_pyobject(vm), value.clone(), entry); + inner.unchecked_push( + index_index, + hash, + key.to_pyobject(vm), + value.clone(), + index_entry, + ); break value; } }; @@ -445,8 +444,8 @@ impl Dict { let hash = key.key_hash(vm)?; let res = loop { let lookup = self.lookup(vm, key, hash, None)?; - let (entry, index_index) = lookup; - if let IndexEntry::Index(index) = entry { + let (index_entry, index_index) = lookup; + if let Some(index) = index_entry.index() { let inner = self.read(); if let Some(entry) = inner.entries.get(index) { let entry = extract_dict_entry(entry); @@ -464,7 +463,7 @@ impl Dict { let key = key.to_pyobject(vm); let mut inner = self.write(); let ret = (key.clone(), value.clone()); - inner.unchecked_push(index_index, hash, key, value, entry); + inner.unchecked_push(index_index, hash, key, value, index_entry); break ret; } }; @@ -541,22 +540,26 @@ impl Dict { }); loop { let index_index = idxs.next(); - match IndexEntry::from(inner.indices[index_index]) { - IndexEntry::Dummy => { + match inner.indices[index_index] { + IndexEntry::DUMMY => { if freeslot.is_none() { freeslot = Some(index_index); } } - IndexEntry::Free => { + IndexEntry::FREE => { let idxs = match freeslot { - Some(free) => (IndexEntry::Dummy, free), - None => (IndexEntry::Free, index_index), + Some(free) => (IndexEntry::DUMMY, free), + None => (IndexEntry::FREE, index_index), }; return Ok(idxs); } - IndexEntry::Index(i) => { + idx => { + let i = idx.index().unwrap_or_else(|| unsafe { + // DUMMY and FREE is already checked above. + std::hint::unreachable_unchecked() + }); let entry = &inner.entries[i].as_ref().unwrap(); - let ret = (IndexEntry::Index(i), index_index); + let ret = (idx, index_index); if key.key_is(&entry.key) { break 'outer ret; } else if entry.hash == hash_value { @@ -593,7 +596,7 @@ impl Dict { pred: impl Fn(&T) -> Result, ) -> Result, E> { let (entry_index, index_index) = lookup; - let entry_index = if let IndexEntry::Index(entry_index) = entry_index { + let entry_index = if let Some(entry_index) = entry_index.index() { entry_index } else { return Ok(ControlFlow::Break(None));