diff --git a/tests/snippets/builtin_range.py b/tests/snippets/builtin_range.py index e0918bd3f..90b92b430 100644 --- a/tests/snippets/builtin_range.py +++ b/tests/snippets/builtin_range.py @@ -58,6 +58,12 @@ assert range(2) == range(0, 2) assert range(0, 10, 3) == range(0, 12, 3) assert range(20, 10, 3) == range(20, 12, 3) +assert range(10).__eq__(range(0, 10, 1)) is True +assert range(10).__ne__(range(0, 10, 1)) is False +assert range(10).__eq__(range(0, 11, 1)) is False +assert range(10).__ne__(range(0, 11, 1)) is True +assert range(0, 10, 3).__eq__(range(0, 11, 3)) is True +assert range(0, 10, 3).__ne__(range(0, 11, 3)) is False #__lt__ assert range(1, 2, 3).__lt__(range(1, 2, 3)) == NotImplemented assert range(1, 2, 1).__lt__(range(1, 2)) == NotImplemented diff --git a/tests/snippets/floats.py b/tests/snippets/floats.py index 73bc2bbc4..5cdbda547 100644 --- a/tests/snippets/floats.py +++ b/tests/snippets/floats.py @@ -229,3 +229,34 @@ assert float('nan').hex() == 'nan' # Test float exponent: assert 1 if 1else 0 == 1 +a = 3. +assert a.__eq__(3) is True +assert a.__eq__(3.) is True +assert a.__eq__(3.00000) is True +assert a.__eq__(3.01) is False + +pi = 3.14 +assert pi.__eq__(3.14) is True +assert pi.__ne__(3.14) is False +assert pi.__eq__(3) is False +assert pi.__ne__(3) is True +assert pi.__eq__('pi') is NotImplemented +assert pi.__ne__('pi') is NotImplemented + +assert pi.__eq__(float('inf')) is False +assert pi.__ne__(float('inf')) is True +assert float('inf').__eq__(pi) is False +assert float('inf').__ne__(pi) is True +assert float('inf').__eq__(float('inf')) is True +assert float('inf').__ne__(float('inf')) is False +assert float('inf').__eq__(float('nan')) is False +assert float('inf').__ne__(float('nan')) is True + +assert pi.__eq__(float('nan')) is False +assert pi.__ne__(float('nan')) is True +assert float('nan').__eq__(pi) is False +assert float('nan').__ne__(pi) is True +assert float('nan').__eq__(float('nan')) is False +assert float('nan').__ne__(float('nan')) is True +assert float('nan').__eq__(float('inf')) is False +assert float('nan').__ne__(float('inf')) is True diff --git a/tests/snippets/none.py b/tests/snippets/none.py index a77079162..b0080a9d2 100644 --- a/tests/snippets/none.py +++ b/tests/snippets/none.py @@ -18,3 +18,9 @@ assert none() is none2() assert str(None) == 'None' assert repr(None) == 'None' assert type(None)() is None + +assert None.__eq__(3) is NotImplemented +assert None.__ne__(3) is NotImplemented +assert None.__eq__(None) is True +assert None.__ne__(None) is False + diff --git a/vm/src/obj/objfloat.rs b/vm/src/obj/objfloat.rs index fbcc6fcc0..eeef545c6 100644 --- a/vm/src/obj/objfloat.rs +++ b/vm/src/obj/objfloat.rs @@ -179,20 +179,39 @@ impl PyFloat { PyFloat::from(float_val?).into_ref_with_type(vm, cls) } + fn float_eq(&self, other: PyObjectRef) -> bool { + let other = get_value(&other); + self.value == other + } + + fn int_eq(&self, other: PyObjectRef) -> bool { + let other_int = objint::get_value(&other); + let value = self.value; + if let (Some(self_int), Some(other_float)) = (value.to_bigint(), other_int.to_f64()) { + value == other_float && self_int == *other_int + } else { + false + } + } + #[pymethod(name = "__eq__")] fn eq(&self, other: PyObjectRef, vm: &VirtualMachine) -> PyObjectRef { - let value = self.value; let result = if objtype::isinstance(&other, &vm.ctx.float_type()) { - let other = get_value(&other); - value == other + self.float_eq(other) } else if objtype::isinstance(&other, &vm.ctx.int_type()) { - let other_int = objint::get_value(&other); + self.int_eq(other) + } else { + return vm.ctx.not_implemented(); + }; + vm.ctx.new_bool(result) + } - if let (Some(self_int), Some(other_float)) = (value.to_bigint(), other_int.to_f64()) { - value == other_float && self_int == *other_int - } else { - false - } + #[pymethod(name = "__ne__")] + fn ne(&self, other: PyObjectRef, vm: &VirtualMachine) -> PyObjectRef { + let result = if objtype::isinstance(&other, &vm.ctx.float_type()) { + !self.float_eq(other) + } else if objtype::isinstance(&other, &vm.ctx.int_type()) { + !self.int_eq(other) } else { return vm.ctx.not_implemented(); }; diff --git a/vm/src/obj/objnone.rs b/vm/src/obj/objnone.rs index 1a6412327..4407d0287 100644 --- a/vm/src/obj/objnone.rs +++ b/vm/src/obj/objnone.rs @@ -111,6 +111,24 @@ impl PyNoneRef { Err(vm.new_attribute_error(format!("{} has no attribute '{}'", self.as_object(), name))) } } + + #[pymethod(name = "__eq__")] + fn eq(self, rhs: PyObjectRef, vm: &VirtualMachine) -> PyObjectRef { + if vm.is_none(&rhs) { + vm.ctx.new_bool(true) + } else { + vm.ctx.not_implemented() + } + } + + #[pymethod(name = "__ne__")] + fn ne(self, rhs: PyObjectRef, vm: &VirtualMachine) -> PyObjectRef { + if vm.is_none(&rhs) { + vm.ctx.new_bool(false) + } else { + vm.ctx.not_implemented() + } + } } pub fn init(context: &PyContext) { diff --git a/vm/src/obj/objrange.rs b/vm/src/obj/objrange.rs index e71d03255..799242255 100644 --- a/vm/src/obj/objrange.rs +++ b/vm/src/obj/objrange.rs @@ -14,7 +14,7 @@ use crate::vm::VirtualMachine; use super::objint::{PyInt, PyIntRef}; use super::objiter; use super::objslice::{PySlice, PySliceRef}; -use super::objtype::{self, PyClassRef}; +use super::objtype::PyClassRef; /// range(stop) -> range object /// range(start, stop[, step]) -> range object @@ -239,30 +239,43 @@ impl PyRange { } } + fn inner_eq(&self, rhs: &PyRange) -> bool { + if self.length() != rhs.length() { + return false; + } + + if self.length().is_zero() { + return true; + } + + if self.start.as_bigint() != rhs.start.as_bigint() { + return false; + } + let step = self.step.as_bigint(); + if step.is_one() || step == rhs.step.as_bigint() { + return true; + } + + false + } + #[pymethod(name = "__eq__")] - fn eq(&self, rhs: PyObjectRef, vm: &VirtualMachine) -> bool { - if objtype::isinstance(&rhs, &vm.ctx.range_type()) { - let rhs = get_value(&rhs); - - if self.length() != rhs.length() { - return false; - } - - if self.length().is_zero() { - return true; - } - - if self.start.as_bigint() != rhs.start.as_bigint() { - return false; - } - let step = self.step.as_bigint(); - if step.is_one() || step == rhs.step.as_bigint() { - return true; - } - - false + fn eq(&self, rhs: PyObjectRef, vm: &VirtualMachine) -> PyResult { + if let Some(rhs) = rhs.payload::() { + let eq = self.inner_eq(rhs); + Ok(vm.ctx.new_bool(eq)) } else { - false + Ok(vm.ctx.not_implemented()) + } + } + + #[pymethod(name = "__ne__")] + fn ne(&self, rhs: PyObjectRef, vm: &VirtualMachine) -> PyResult { + if let Some(rhs) = rhs.payload::() { + let eq = self.inner_eq(rhs); + Ok(vm.ctx.new_bool(!eq)) + } else { + Ok(vm.ctx.not_implemented()) } }