From bd5772d9142f8bbd93e65a6ca4ec4cb214bbc4ac Mon Sep 17 00:00:00 2001 From: Adam Kelly Date: Tue, 9 Apr 2019 12:03:02 +0100 Subject: [PATCH] Implement dict.__eq__ --- tests/snippets/dict.py | 26 ++++++++++++++------------ vm/src/obj/objdict.rs | 31 +++++++++++++++++++++++++++++++ 2 files changed, 45 insertions(+), 12 deletions(-) diff --git a/tests/snippets/dict.py b/tests/snippets/dict.py index 45c6c3b43..6507efb19 100644 --- a/tests/snippets/dict.py +++ b/tests/snippets/dict.py @@ -1,21 +1,23 @@ from testutils import assertRaises -def dict_eq(d1, d2): - return (all(k in d2 and d1[k] == d2[k] for k in d1) - and all(k in d1 and d1[k] == d2[k] for k in d2)) +assert dict(a=2, b=3) == {'a': 2, 'b': 3} +assert dict({'a': 2, 'b': 3}, b=4) == {'a': 2, 'b': 4} +assert dict([('a', 2), ('b', 3)]) == {'a': 2, 'b': 3} - -assert dict_eq(dict(a=2, b=3), {'a': 2, 'b': 3}) -assert dict_eq(dict({'a': 2, 'b': 3}, b=4), {'a': 2, 'b': 4}) -assert dict_eq(dict([('a', 2), ('b', 3)]), {'a': 2, 'b': 3}) +assert {} == {} +assert not {'a': 2} == {} +assert not {} == {'a': 2} +assert not {'b': 2} == {'a': 2} +assert not {'a': 4} == {'a': 2} +assert {'a': 2} == {'a': 2} a = {'g': 5} b = {'a': a, 'd': 9} c = dict(b) c['d'] = 3 c['a']['g'] = 2 -assert dict_eq(a, {'g': 2}) -assert dict_eq(b, {'a': a, 'd': 9}) +assert a == {'g': 2} +assert b == {'a': a, 'd': 9} a.clear() assert len(a) == 0 @@ -142,10 +144,10 @@ assert list(x) == ['a', 'b'] y = x.copy() x['c'] = 12 -assert dict_eq(y, {'a': 2, 'b': 10}) +assert y == {'a': 2, 'b': 10} y.update({'c': 19, "d": -1, 'b': 12}) -assert dict_eq(y, {'a': 2, 'b': 12, 'c': 19, 'd': -1}) +assert y == {'a': 2, 'b': 12, 'c': 19, 'd': -1} y.update(y) -assert dict_eq(y, {'a': 2, 'b': 12, 'c': 19, 'd': -1}) # hasn't changed +assert y == {'a': 2, 'b': 12, 'c': 19, 'd': -1} # hasn't changed diff --git a/vm/src/obj/objdict.rs b/vm/src/obj/objdict.rs index 4a503fdcf..8e34ca738 100644 --- a/vm/src/obj/objdict.rs +++ b/vm/src/obj/objdict.rs @@ -7,6 +7,7 @@ use crate::pyobject::{ }; use crate::vm::{ReprGuard, VirtualMachine}; +use super::objbool; use super::objiter; use super::objstr; use crate::dictdatatype; @@ -96,6 +97,35 @@ impl PyDictRef { !self.entries.borrow().is_empty() } + fn inner_eq(self, other: &PyDict, vm: &VirtualMachine) -> PyResult { + if other.entries.borrow().len() != self.entries.borrow().len() { + return Ok(false); + } + for (k, v1) in self { + match other.entries.borrow().get(vm, &k)? { + Some(v2) => { + let value = objbool::boolval(vm, vm._eq(v1, v2)?)?; + if !value { + return Ok(false); + } + } + None => { + return Ok(false); + } + } + } + return Ok(true); + } + + fn eq(self, other: PyObjectRef, vm: &VirtualMachine) -> PyResult { + if let Some(other) = other.payload::() { + let eq = self.inner_eq(other, vm)?; + Ok(vm.ctx.new_bool(eq)) + } else { + Ok(vm.ctx.not_implemented()) + } + } + fn len(self, _vm: &VirtualMachine) -> usize { self.entries.borrow().len() } @@ -387,6 +417,7 @@ pub fn init(context: &PyContext) { "__len__" => context.new_rustfunc(PyDictRef::len), "__contains__" => context.new_rustfunc(PyDictRef::contains), "__delitem__" => context.new_rustfunc(PyDictRef::inner_delitem), + "__eq__" => context.new_rustfunc(PyDictRef::eq), "__getitem__" => context.new_rustfunc(PyDictRef::inner_getitem), "__iter__" => context.new_rustfunc(PyDictRef::iter), "__new__" => context.new_rustfunc(PyDictRef::new),