diff --git a/vm/src/dictdatatype.rs b/vm/src/dictdatatype.rs index 41576c8a5b..25b80fa7e9 100644 --- a/vm/src/dictdatatype.rs +++ b/vm/src/dictdatatype.rs @@ -1,4 +1,4 @@ -use crate::obj::objstr::PyString; +use crate::obj::objstr::{PyString, PyStringRef}; use crate::pyhash; use crate::pyobject::{IdProtocol, IntoPyObject, PyObjectRef, PyResult}; use crate::vm::VirtualMachine; @@ -438,6 +438,26 @@ impl DictKey for &PyObjectRef { } } +impl DictKey for &PyStringRef { + fn do_hash(self, _vm: &VirtualMachine) -> PyResult { + Ok(self.hash()) + } + + fn do_is(self, other: &PyObjectRef) -> bool { + self.is(other) + } + + fn do_eq(self, vm: &VirtualMachine, other_key: &PyObjectRef) -> PyResult { + if self.is(other_key) { + Ok(true) + } else if let Some(py_str_value) = other_key.payload::() { + Ok(py_str_value.as_str() == self.as_str()) + } else { + vm.bool_eq(self.clone().into_object(), other_key.clone()) + } + } +} + /// Implement trait for the str type, so that we can use strings /// to index dictionaries. impl DictKey for &str { diff --git a/vm/src/obj/objstr.rs b/vm/src/obj/objstr.rs index 4a6c16c989..7bbff6b9b9 100644 --- a/vm/src/obj/objstr.rs +++ b/vm/src/obj/objstr.rs @@ -292,7 +292,7 @@ impl PyString { } #[pymethod(name = "__hash__")] - fn hash(&self) -> pyhash::PyHash { + pub(crate) fn hash(&self) -> pyhash::PyHash { self.hash.load().unwrap_or_else(|| { let hash = pyhash::hash_value(&self.value); self.hash.store(Some(hash));