diff --git a/vm/src/obj/objdict.rs b/vm/src/obj/objdict.rs index a0d9210b5..e80fa7d8a 100644 --- a/vm/src/obj/objdict.rs +++ b/vm/src/obj/objdict.rs @@ -130,16 +130,20 @@ impl PyDictRef { self.entries.borrow_mut().clear() } - fn iter(self, vm: &VirtualMachine) -> PyDictKeysIteratorRef { - PyDictKeysIteratorRef::new(self, vm) + fn iter(self, _vm: &VirtualMachine) -> PyDictKeyIterator { + PyDictKeyIterator::new(self) } - fn values(self, vm: &VirtualMachine) -> PyDictValuesIteratorRef { - PyDictValuesIteratorRef::new(self, vm) + fn keys(self, _vm: &VirtualMachine) -> PyDictKeys { + PyDictKeys::new(self) } - fn items(self, vm: &VirtualMachine) -> PyDictItemsIteratorRef { - PyDictItemsIteratorRef::new(self, vm) + fn values(self, _vm: &VirtualMachine) -> PyDictValues { + PyDictValues::new(self) + } + + fn items(self, _vm: &VirtualMachine) -> PyDictItems { + PyDictItems::new(self) } pub fn get_key_value_pairs(&self) -> Vec<(PyObjectRef, PyObjectRef)> { @@ -243,116 +247,89 @@ impl ItemProtocol for PyDictRef { } } -#[derive(Debug)] -struct PyDictKeysIterator { - pub dict: PyDictRef, - pub position: Cell, -} -type PyDictKeysIteratorRef = PyRef; - -impl PyDictKeysIteratorRef { - fn new(dict: PyDictRef, vm: &VirtualMachine) -> PyDictKeysIteratorRef { - PyDictKeysIterator { - position: Cell::new(0), - dict, +macro_rules! dict_iterator { + ( $name: ident, $iter_name: ident, $class: ident, $iter_class: ident, $result_fn: expr) => { + #[derive(Debug)] + struct $name { + pub dict: PyDictRef, } - .into_ref(vm) - } - fn next(self, vm: &VirtualMachine) -> PyResult { - match self.dict.entries.borrow().next_entry(self.position.get()) { - Some((new_position, key, _value)) => { - self.position.set(new_position); - Ok(key.clone()) + impl $name { + fn new(dict: PyDictRef) -> Self { + $name { dict: dict } } - None => Err(objiter::new_stop_iteration(vm)), - } - } - fn iter(self, _vm: &VirtualMachine) -> Self { - self - } -} - -impl PyValue for PyDictKeysIterator { - fn class(vm: &VirtualMachine) -> PyClassRef { - vm.ctx.dictkeysiterator_type.clone() - } -} - -#[derive(Debug)] -struct PyDictValuesIterator { - pub dict: PyDictRef, - pub position: Cell, -} -type PyDictValuesIteratorRef = PyRef; - -impl PyDictValuesIteratorRef { - fn new(dict: PyDictRef, vm: &VirtualMachine) -> PyDictValuesIteratorRef { - PyDictValuesIterator { - position: Cell::new(0), - dict, - } - .into_ref(vm) - } - - fn next(self, vm: &VirtualMachine) -> PyResult { - match self.dict.entries.borrow().next_entry(self.position.get()) { - Some((new_position, _key, value)) => { - self.position.set(new_position); - Ok(value.clone()) + fn iter(&self, _vm: &VirtualMachine) -> $iter_name { + $iter_name::new(self.dict.clone()) } - None => Err(objiter::new_stop_iteration(vm)), } - } - fn iter(self, _vm: &VirtualMachine) -> Self { - self - } -} - -impl PyValue for PyDictValuesIterator { - fn class(vm: &VirtualMachine) -> PyClassRef { - vm.ctx.dictvaluesiterator_type.clone() - } -} - -#[derive(Debug)] -struct PyDictItemsIterator { - pub dict: PyDictRef, - pub position: Cell, -} - -type PyDictItemsIteratorRef = PyRef; - -impl PyDictItemsIteratorRef { - fn new(dict: PyDictRef, vm: &VirtualMachine) -> PyDictItemsIteratorRef { - PyDictItemsIterator { - position: Cell::new(0), - dict, - } - .into_ref(vm) - } - - fn next(self: PyDictItemsIteratorRef, vm: &VirtualMachine) -> PyResult { - match self.dict.entries.borrow().next_entry(self.position.get()) { - Some((new_position, key, value)) => { - self.position.set(new_position); - Ok(vm.ctx.new_tuple(vec![key.clone(), value.clone()])) + impl PyValue for $name { + fn class(vm: &VirtualMachine) -> PyClassRef { + vm.ctx.$class.clone() } - None => Err(objiter::new_stop_iteration(vm)), } - } - fn iter(self, _vm: &VirtualMachine) -> Self { - self - } + #[derive(Debug)] + struct $iter_name { + pub dict: PyDictRef, + pub position: Cell, + } + + impl $iter_name { + fn new(dict: PyDictRef) -> Self { + $iter_name { + position: Cell::new(0), + dict, + } + } + + fn next(&self, vm: &VirtualMachine) -> PyResult { + match self.dict.entries.borrow().next_entry(self.position.get()) { + Some((new_position, key, value)) => { + self.position.set(new_position); + Ok($result_fn(vm, key, value)) + } + None => Err(objiter::new_stop_iteration(vm)), + } + } + + fn iter(zelf: PyRef, _vm: &VirtualMachine) -> PyRef { + zelf + } + } + + impl PyValue for $iter_name { + fn class(vm: &VirtualMachine) -> PyClassRef { + vm.ctx.$iter_class.clone() + } + } + }; } -impl PyValue for PyDictItemsIterator { - fn class(vm: &VirtualMachine) -> PyClassRef { - vm.ctx.dictitemsiterator_type.clone() - } +dict_iterator! { + PyDictKeys, + PyDictKeyIterator, + dictkeys_type, + dictkeyiterator_type, + |_vm: &VirtualMachine, key: &PyObjectRef, _value: &PyObjectRef| key.clone() +} + +dict_iterator! { + PyDictValues, + PyDictValueIterator, + dictvalues_type, + dictvalueiterator_type, + |_vm: &VirtualMachine, _key: &PyObjectRef, value: &PyObjectRef| value.clone() +} + +dict_iterator! { + PyDictItems, + PyDictItemIterator, + dictitems_type, + dictitemiterator_type, + |vm: &VirtualMachine, key: &PyObjectRef, value: &PyObjectRef| + vm.ctx.new_tuple(vec![key.clone(), value.clone()]) } pub fn init(context: &PyContext) { @@ -371,24 +348,36 @@ pub fn init(context: &PyContext) { "values" => context.new_rustfunc(PyDictRef::values), "items" => context.new_rustfunc(PyDictRef::items), // TODO: separate type. `keys` should be a live view over the collection, not an iterator. - "keys" => context.new_rustfunc(PyDictRef::iter), + "keys" => context.new_rustfunc(PyDictRef::keys), "get" => context.new_rustfunc(PyDictRef::get), "copy" => context.new_rustfunc(PyDictRef::copy), "update" => context.new_rustfunc(PyDictRef::update), }); - extend_class!(context, &context.dictkeysiterator_type, { - "__next__" => context.new_rustfunc(PyDictKeysIteratorRef::next), - "__iter__" => context.new_rustfunc(PyDictKeysIteratorRef::iter), + extend_class!(context, &context.dictkeys_type, { + "__iter__" => context.new_rustfunc(PyDictKeys::iter), }); - extend_class!(context, &context.dictvaluesiterator_type, { - "__next__" => context.new_rustfunc(PyDictValuesIteratorRef::next), - "__iter__" => context.new_rustfunc(PyDictValuesIteratorRef::iter), + extend_class!(context, &context.dictkeyiterator_type, { + "__next__" => context.new_rustfunc(PyDictKeyIterator::next), + "__iter__" => context.new_rustfunc(PyDictKeyIterator::iter), }); - extend_class!(context, &context.dictitemsiterator_type, { - "__next__" => context.new_rustfunc(PyDictItemsIteratorRef::next), - "__iter__" => context.new_rustfunc(PyDictItemsIteratorRef::iter), + extend_class!(context, &context.dictvalues_type, { + "__iter__" => context.new_rustfunc(PyDictValues::iter), + }); + + extend_class!(context, &context.dictvalueiterator_type, { + "__next__" => context.new_rustfunc(PyDictValueIterator::next), + "__iter__" => context.new_rustfunc(PyDictValueIterator::iter), + }); + + extend_class!(context, &context.dictitems_type, { + "__iter__" => context.new_rustfunc(PyDictItems::iter), + }); + + extend_class!(context, &context.dictitemiterator_type, { + "__next__" => context.new_rustfunc(PyDictItemIterator::next), + "__iter__" => context.new_rustfunc(PyDictItemIterator::iter), }); } diff --git a/vm/src/pyobject.rs b/vm/src/pyobject.rs index 8c405c5b1..434693ea6 100644 --- a/vm/src/pyobject.rs +++ b/vm/src/pyobject.rs @@ -129,9 +129,12 @@ pub struct PyContext { pub false_value: PyIntRef, pub list_type: PyClassRef, pub listiterator_type: PyClassRef, - pub dictkeysiterator_type: PyClassRef, - pub dictvaluesiterator_type: PyClassRef, - pub dictitemsiterator_type: PyClassRef, + pub dictkeyiterator_type: PyClassRef, + pub dictvalueiterator_type: PyClassRef, + pub dictitemiterator_type: PyClassRef, + pub dictkeys_type: PyClassRef, + pub dictvalues_type: PyClassRef, + pub dictitems_type: PyClassRef, pub map_type: PyClassRef, pub memoryview_type: PyClassRef, pub none: PyNoneRef, @@ -257,9 +260,12 @@ impl PyContext { let str_type = create_type("str", &type_type, &object_type); let list_type = create_type("list", &type_type, &object_type); let listiterator_type = create_type("list_iterator", &type_type, &object_type); - let dictkeysiterator_type = create_type("dict_keys", &type_type, &object_type); - let dictvaluesiterator_type = create_type("dict_values", &type_type, &object_type); - let dictitemsiterator_type = create_type("dict_items", &type_type, &object_type); + let dictkeys_type = create_type("dict_keys", &type_type, &object_type); + let dictvalues_type = create_type("dict_values", &type_type, &object_type); + let dictitems_type = create_type("dict_items", &type_type, &object_type); + let dictkeyiterator_type = create_type("dict_keyiterator", &type_type, &object_type); + let dictvalueiterator_type = create_type("dict_valueiterator", &type_type, &object_type); + let dictitemiterator_type = create_type("dict_itemiterator", &type_type, &object_type); let set_type = create_type("set", &type_type, &object_type); let frozenset_type = create_type("frozenset", &type_type, &object_type); let int_type = create_type("int", &type_type, &object_type); @@ -319,9 +325,12 @@ impl PyContext { staticmethod_type, list_type, listiterator_type, - dictkeysiterator_type, - dictvaluesiterator_type, - dictitemsiterator_type, + dictkeys_type, + dictvalues_type, + dictitems_type, + dictkeyiterator_type, + dictvalueiterator_type, + dictitemiterator_type, set_type, frozenset_type, true_value,