From ec5fd550debd3c2bbdc646215484dfc81fb3de3f Mon Sep 17 00:00:00 2001 From: Adam Kelly Date: Tue, 9 Apr 2019 11:25:18 +0100 Subject: [PATCH] Implement IntoIterator for PyDictRef. --- vm/src/dictdatatype.rs | 12 ------- vm/src/frame.rs | 15 ++++----- vm/src/obj/objdict.rs | 73 ++++++++++++++++++++++++++++++----------- vm/src/obj/objmodule.rs | 6 +--- vm/src/obj/objobject.rs | 2 +- vm/src/stdlib/json.rs | 2 +- 6 files changed, 63 insertions(+), 47 deletions(-) diff --git a/vm/src/dictdatatype.rs b/vm/src/dictdatatype.rs index 0612bc8a3..6dc49125b 100644 --- a/vm/src/dictdatatype.rs +++ b/vm/src/dictdatatype.rs @@ -131,18 +131,6 @@ impl Dict { None } - pub fn iter_items(&self) -> impl Iterator + '_ { - self.entries - .iter() - .filter(|e| e.is_some()) - .map(|e| e.as_ref().unwrap()) - .map(|e| (e.key.clone(), e.value.clone())) - } - - pub fn get_items(&self) -> Vec<(PyObjectRef, T)> { - self.iter_items().collect() - } - /// Lookup the index for the given key. fn lookup(&self, vm: &VirtualMachine, key: &PyObjectRef) -> PyResult { let hash_value = calc_hash(vm, key)?; diff --git a/vm/src/frame.rs b/vm/src/frame.rs index 2570c0934..ff9baba35 100644 --- a/vm/src/frame.rs +++ b/vm/src/frame.rs @@ -393,9 +393,8 @@ impl Frame { // Take all key-value pairs from the dict: let dict: PyDictRef = obj.downcast().expect("Need a dictionary to build a map."); - let dict_elements = dict.get_key_value_pairs(); - for (key, value) in dict_elements.iter() { - map_obj.set_item(key.clone(), value.clone(), vm).unwrap(); + for (key, value) in dict { + map_obj.set_item(key, value, vm).unwrap(); } } } else { @@ -636,8 +635,7 @@ impl Frame { let kwargs = if *has_kwargs { let kw_dict: PyDictRef = self.pop_value().downcast().expect("Kwargs must be a dict."); - let dict_elements = kw_dict.get_key_value_pairs(); - dict_elements + kw_dict .into_iter() .map(|elem| (objstr::get_value(&elem.0), elem.1)) .collect() @@ -862,8 +860,8 @@ impl Frame { // Grab all the names from the module and put them in the context if let Some(dict) = &module.dict { - for (k, v) in dict.get_key_value_pairs().iter() { - self.scope.store_name(&vm, &objstr::get_value(k), v.clone()); + for (k, v) in dict { + self.scope.store_name(&vm, &objstr::get_value(&k), v); } } Ok(None) @@ -1230,8 +1228,7 @@ impl fmt::Debug for Frame { .collect::(); let dict = self.scope.get_locals(); let local_str = dict - .get_key_value_pairs() - .iter() + .into_iter() .map(|elem| format!("\n {:?} = {:?}", elem.0, elem.1)) .collect::(); write!( diff --git a/vm/src/obj/objdict.rs b/vm/src/obj/objdict.rs index 53445e01f..b0d99d4e7 100644 --- a/vm/src/obj/objdict.rs +++ b/vm/src/obj/objdict.rs @@ -3,8 +3,7 @@ use std::fmt; use crate::function::{KwArgs, OptionalArg}; use crate::pyobject::{ - IdProtocol, IntoPyObject, ItemProtocol, PyAttributes, PyContext, PyObjectRef, PyRef, PyResult, - PyValue, + IntoPyObject, ItemProtocol, PyAttributes, PyContext, PyObjectRef, PyRef, PyResult, PyValue, }; use crate::vm::{ReprGuard, VirtualMachine}; @@ -63,9 +62,8 @@ impl PyDictRef { if let OptionalArg::Present(dict_obj) = dict_obj { let dicted: PyResult = dict_obj.clone().downcast(); if let Ok(dict_obj) = dicted { - let mut dict_borrowed = dict.borrow_mut(); - for (key, value) in dict_obj.entries.borrow().iter_items() { - dict_borrowed.insert(vm, &key, value)?; + for (key, value) in dict_obj { + dict.borrow_mut().insert(vm, &key, value)?; } } else { let iter = objiter::get_iter(vm, &dict_obj)?; @@ -105,7 +103,7 @@ impl PyDictRef { fn repr(self, vm: &VirtualMachine) -> PyResult { let s = if let Some(_guard) = ReprGuard::enter(self.as_object()) { let mut str_parts = vec![]; - for (key, value) in self.get_key_value_pairs() { + for (key, value) in self { let key_repr = vm.to_repr(&key)?; let value_repr = vm.to_repr(&value)?; str_parts.push(format!("{}: {}", key_repr.value, value_repr.value)); @@ -146,10 +144,6 @@ impl PyDictRef { PyDictItems::new(self) } - pub fn get_key_value_pairs(&self) -> Vec<(PyObjectRef, PyObjectRef)> { - self.entries.borrow().get_items() - } - fn inner_setitem( self, key: PyObjectRef, @@ -194,24 +188,17 @@ impl PyDictRef { fn update( self, - mut dict_obj: OptionalArg, + dict_obj: OptionalArg, kwargs: KwArgs, vm: &VirtualMachine, ) -> PyResult<()> { - if let OptionalArg::Present(ref other) = dict_obj { - if self.is(other) { - // Updating yourself is a noop, and this avoids a borrow error - dict_obj = OptionalArg::Missing; - } - } - PyDictRef::merge(&self.entries, dict_obj, kwargs, vm) } /// Take a python dictionary and convert it to attributes. pub fn to_attributes(self) -> PyAttributes { let mut attrs = PyAttributes::new(); - for (key, value) in self.get_key_value_pairs() { + for (key, value) in self { let key = objstr::get_value(&key); attrs.insert(key, value); } @@ -247,6 +234,54 @@ impl ItemProtocol for PyDictRef { } } +// Implement IntoIterator so that we can easily iterate dictionaries from rust code. +impl IntoIterator for PyDictRef { + type Item = (PyObjectRef, PyObjectRef); + type IntoIter = DictIterator; + + fn into_iter(self) -> Self::IntoIter { + DictIterator::new(self) + } +} + +impl IntoIterator for &PyDictRef { + type Item = (PyObjectRef, PyObjectRef); + type IntoIter = DictIterator; + + fn into_iter(self) -> Self::IntoIter { + DictIterator::new(self.clone()) + } +} + +pub struct DictIterator { + dict: PyDictRef, + position: Cell, +} + +impl DictIterator { + pub fn new(dict: PyDictRef) -> DictIterator { + DictIterator { + dict, + position: Cell::new(0), + } + } +} + +impl Iterator for DictIterator { + type Item = (PyObjectRef, PyObjectRef); + + fn next(&mut self) -> Option { + self.dict + .entries + .borrow() + .next_entry(self.position.get()) + .map(|(new_position, key, value)| { + self.position.set(new_position); + (key.clone(), value.clone()) + }) + } +} + macro_rules! dict_iterator { ( $name: ident, $iter_name: ident, $class: ident, $iter_class: ident, $class_name: literal, $iter_class_name: literal, $result_fn: expr) => { #[pyclass(name = $class_name, __inside_vm)] diff --git a/vm/src/obj/objmodule.rs b/vm/src/obj/objmodule.rs index 75749d557..904487e4d 100644 --- a/vm/src/obj/objmodule.rs +++ b/vm/src/obj/objmodule.rs @@ -17,11 +17,7 @@ impl PyValue for PyModule { impl PyModuleRef { fn dir(self: PyModuleRef, vm: &VirtualMachine) -> PyResult { if let Some(dict) = &self.into_object().dict { - let keys = dict - .get_key_value_pairs() - .iter() - .map(|(k, _v)| k.clone()) - .collect(); + let keys = dict.into_iter().map(|(k, _v)| k.clone()).collect(); Ok(vm.ctx.new_list(keys)) } else { panic!("Modules should definitely have a dict."); diff --git a/vm/src/obj/objobject.rs b/vm/src/obj/objobject.rs index a3bc41103..f8e86d0a1 100644 --- a/vm/src/obj/objobject.rs +++ b/vm/src/obj/objobject.rs @@ -237,7 +237,7 @@ pub fn get_attributes(obj: &PyObjectRef) -> PyAttributes { // Get instance attributes: if let Some(dict) = &obj.dict { - for (key, value) in dict.get_key_value_pairs() { + for (key, value) in dict.into_iter() { attributes.insert(key.to_string(), value.clone()); } } diff --git a/vm/src/stdlib/json.rs b/vm/src/stdlib/json.rs index f4061bc2a..3fa30e16b 100644 --- a/vm/src/stdlib/json.rs +++ b/vm/src/stdlib/json.rs @@ -64,7 +64,7 @@ impl<'s> serde::Serialize for PyObjectSerializer<'s> { serialize_seq_elements(serializer, &elements) } else if objtype::isinstance(self.pyobject, &self.vm.ctx.dict_type()) { let dict: PyDictRef = self.pyobject.clone().downcast().unwrap(); - let pairs = dict.get_key_value_pairs(); + let pairs: Vec<(PyObjectRef, PyObjectRef)> = dict.into_iter().collect(); let mut map = serializer.serialize_map(Some(pairs.len()))?; for (key, e) in pairs.iter() { map.serialize_entry(&self.clone_with_object(key), &self.clone_with_object(&e))?;