Make PyDict ThreadSafe

This commit is contained in:
Aviv Palivoda
2020-04-18 16:58:06 +03:00
parent 38cb24df66
commit 08f74bb63c

View File

@@ -1,4 +1,4 @@
use std::cell::{Cell, RefCell};
use std::cell::Cell;
use std::fmt;
use super::objiter;
@@ -9,7 +9,7 @@ use crate::exceptions::PyBaseExceptionRef;
use crate::function::{KwArgs, OptionalArg, PyFuncArgs};
use crate::pyobject::{
IdProtocol, IntoPyObject, ItemProtocol, PyAttributes, PyClassImpl, PyContext, PyIterable,
PyObjectRef, PyRef, PyResult, PyValue,
PyObjectRef, PyRef, PyResult, PyValue, ThreadSafe,
};
use crate::vm::{ReprGuard, VirtualMachine};
@@ -20,9 +20,10 @@ pub type DictContentType = dictdatatype::Dict;
#[pyclass]
#[derive(Default)]
pub struct PyDict {
entries: RefCell<DictContentType>,
entries: DictContentType,
}
pub type PyDictRef = PyRef<PyDict>;
impl ThreadSafe for PyDict {}
impl fmt::Debug for PyDict {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
@@ -43,7 +44,7 @@ impl PyDictRef {
#[pyslot]
fn tp_new(class: PyClassRef, _args: PyFuncArgs, vm: &VirtualMachine) -> PyResult<PyDictRef> {
PyDict {
entries: RefCell::new(DictContentType::default()),
entries: DictContentType::default(),
}
.into_ref_with_type(vm, class)
}
@@ -59,7 +60,7 @@ impl PyDictRef {
}
fn merge(
dict: &RefCell<DictContentType>,
dict: &DictContentType,
dict_obj: OptionalArg<PyObjectRef>,
kwargs: KwArgs,
vm: &VirtualMachine,
@@ -68,13 +69,13 @@ impl PyDictRef {
let dicted: Result<PyDictRef, _> = dict_obj.clone().downcast();
if let Ok(dict_obj) = dicted {
for (key, value) in dict_obj {
dict.borrow_mut().insert(vm, &key, value)?;
dict.insert(vm, &key, value)?;
}
} else if let Some(keys) = vm.get_method(dict_obj.clone(), "keys") {
let keys = objiter::get_iter(vm, &vm.invoke(&keys?, vec![])?)?;
while let Some(key) = objiter::get_next_object(vm, &keys)? {
let val = dict_obj.get_item(&key, vm)?;
dict.borrow_mut().insert(vm, &key, val)?;
dict.insert(vm, &key, val)?;
}
} else {
let iter = objiter::get_iter(vm, &dict_obj)?;
@@ -92,14 +93,13 @@ impl PyDictRef {
if objiter::get_next_object(vm, &elem_iter)?.is_some() {
return Err(err(vm));
}
dict.borrow_mut().insert(vm, &key, value)?;
dict.insert(vm, &key, value)?;
}
}
}
let mut dict_borrowed = dict.borrow_mut();
for (key, value) in kwargs.into_iter() {
dict_borrowed.insert(vm, &vm.new_str(key), value)?;
dict.insert(vm, &vm.new_str(key), value)?;
}
Ok(())
}
@@ -111,27 +111,26 @@ impl PyDictRef {
value: OptionalArg<PyObjectRef>,
vm: &VirtualMachine,
) -> PyResult<PyDictRef> {
let mut dict = DictContentType::default();
let dict = DictContentType::default();
let value = value.unwrap_or_else(|| vm.ctx.none());
for elem in iterable.iter(vm)? {
let elem = elem?;
dict.insert(vm, &elem, value.clone())?;
}
let entries = RefCell::new(dict);
PyDict { entries }.into_ref_with_type(vm, class)
PyDict { entries: dict }.into_ref_with_type(vm, class)
}
#[pymethod(magic)]
fn bool(self) -> bool {
!self.entries.borrow().is_empty()
!self.entries.is_empty()
}
fn inner_eq(self, other: &PyDict, vm: &VirtualMachine) -> PyResult<bool> {
if other.entries.borrow().len() != self.entries.borrow().len() {
if other.entries.len() != self.entries.len() {
return Ok(false);
}
for (k, v1) in self {
match other.entries.borrow().get(vm, &k)? {
match other.entries.get(vm, &k)? {
Some(v2) => {
if v1.is(&v2) {
continue;
@@ -170,12 +169,12 @@ impl PyDictRef {
#[pymethod(magic)]
fn len(self) -> usize {
self.entries.borrow().len()
self.entries.len()
}
#[pymethod(magic)]
fn sizeof(self) -> usize {
size_of::<Self>() + self.entries.borrow().sizeof()
size_of::<Self>() + self.entries.sizeof()
}
#[pymethod(magic)]
@@ -197,17 +196,17 @@ impl PyDictRef {
#[pymethod(magic)]
fn contains(self, key: PyObjectRef, vm: &VirtualMachine) -> PyResult<bool> {
self.entries.borrow().contains(vm, &key)
self.entries.contains(vm, &key)
}
#[pymethod(magic)]
fn delitem(self, key: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> {
self.entries.borrow_mut().delete(vm, &key)
self.entries.delete(vm, &key)
}
#[pymethod]
fn clear(self) {
self.entries.borrow_mut().clear()
self.entries.clear()
}
#[pymethod(magic)]
@@ -243,7 +242,7 @@ impl PyDictRef {
value: PyObjectRef,
vm: &VirtualMachine,
) -> PyResult<()> {
self.entries.borrow_mut().insert(vm, key, value)
self.entries.insert(vm, key, value)
}
#[pymethod(magic)]
@@ -262,7 +261,7 @@ impl PyDictRef {
key: K,
vm: &VirtualMachine,
) -> PyResult<Option<PyObjectRef>> {
if let Some(value) = self.entries.borrow().get(vm, key)? {
if let Some(value) = self.entries.get(vm, key)? {
return Ok(Some(value));
}
@@ -281,7 +280,7 @@ impl PyDictRef {
default: OptionalArg<PyObjectRef>,
vm: &VirtualMachine,
) -> PyResult {
match self.entries.borrow().get(vm, &key)? {
match self.entries.get(vm, &key)? {
Some(value) => Ok(value),
None => Ok(default.unwrap_or_else(|| vm.ctx.none())),
}
@@ -294,12 +293,11 @@ impl PyDictRef {
default: OptionalArg<PyObjectRef>,
vm: &VirtualMachine,
) -> PyResult {
let mut entries = self.entries.borrow_mut();
match entries.get(vm, &key)? {
match self.entries.get(vm, &key)? {
Some(value) => Ok(value),
None => {
let set_value = default.unwrap_or_else(|| vm.ctx.none());
entries.insert(vm, &key, set_value.clone())?;
self.entries.insert(vm, &key, set_value.clone())?;
Ok(set_value)
}
}
@@ -329,7 +327,7 @@ impl PyDictRef {
default: OptionalArg<PyObjectRef>,
vm: &VirtualMachine,
) -> PyResult {
match self.entries.borrow_mut().pop(vm, &key)? {
match self.entries.pop(vm, &key)? {
Some(value) => Ok(value),
None => match default {
OptionalArg::Present(default) => Ok(default),
@@ -340,8 +338,7 @@ impl PyDictRef {
#[pymethod]
fn popitem(self, vm: &VirtualMachine) -> PyResult {
let mut entries = self.entries.borrow_mut();
if let Some((key, value)) = entries.pop_front() {
if let Some((key, value)) = self.entries.pop_front() {
Ok(vm.ctx.new_tuple(vec![key, value]))
} else {
let err_msg = vm.new_str("popitem(): dictionary is empty".to_owned());
@@ -360,14 +357,13 @@ impl PyDictRef {
}
pub fn from_attributes(attrs: PyAttributes, vm: &VirtualMachine) -> PyResult<Self> {
let mut dict = DictContentType::default();
let dict = DictContentType::default();
for (key, value) in attrs {
dict.insert(vm, &vm.ctx.new_str(key), value)?;
}
let entries = RefCell::new(dict);
Ok(PyDict { entries }.into_ref(vm))
Ok(PyDict { entries: dict }.into_ref(vm))
}
#[pymethod(magic)]
@@ -377,11 +373,11 @@ impl PyDictRef {
pub fn contains_key<T: IntoPyObject>(&self, key: T, vm: &VirtualMachine) -> bool {
let key = key.into_pyobject(vm).unwrap();
self.entries.borrow().contains(vm, &key).unwrap()
self.entries.contains(vm, &key).unwrap()
}
pub fn size(&self) -> dictdatatype::DictSize {
self.entries.borrow().size()
self.entries.size()
}
/// This function can be used to get an item without raising the
@@ -487,7 +483,7 @@ impl Iterator for DictIter {
type Item = (PyObjectRef, PyObjectRef);
fn next(&mut self) -> Option<Self::Item> {
match self.dict.entries.borrow().next_entry(&mut self.position) {
match self.dict.entries.next_entry(&mut self.position) {
Some((key, value)) => Some((key, value)),
None => None,
}
@@ -563,13 +559,12 @@ macro_rules! dict_iterator {
#[allow(clippy::redundant_closure_call)]
fn next(&self, vm: &VirtualMachine) -> PyResult {
let mut position = self.position.get();
let dict = self.dict.entries.borrow();
if dict.has_changed_size(&self.size) {
if self.dict.entries.has_changed_size(&self.size) {
return Err(
vm.new_runtime_error("dictionary changed size during iteration".to_owned())
);
}
match dict.next_entry(&mut position) {
match self.dict.entries.next_entry(&mut position) {
Some((key, value)) => {
self.position.set(position);
Ok($result_fn(vm, key, value))
@@ -585,10 +580,7 @@ macro_rules! dict_iterator {
#[pymethod(name = "__length_hint__")]
fn length_hint(&self) -> usize {
self.dict
.entries
.borrow()
.len_from_entry_index(self.position.get())
self.dict.entries.len_from_entry_index(self.position.get())
}
}