diff --git a/vm/src/obj/objrange.rs b/vm/src/obj/objrange.rs index eed66687a..6a9a09d69 100644 --- a/vm/src/obj/objrange.rs +++ b/vm/src/obj/objrange.rs @@ -104,6 +104,20 @@ impl PyRange { let result = start + step * index; Some(result) } + + #[inline] + fn length(&self) -> PyInt { + let start = self.start.as_bigint(); + let stop = self.stop.as_bigint(); + let step = self.step.as_bigint(); + + match step.sign() { + Sign::Plus if start < stop => PyInt::new((stop - start - 1usize) / step + 1), + Sign::Minus if start > stop => PyInt::new((start - stop - 1usize) / (-step) + 1), + Sign::Plus | Sign::Minus => PyInt::new(0), + Sign::NoSign => unreachable!(), + } + } } pub fn get_value(obj: &PyObjectRef) -> PyRange { @@ -201,16 +215,7 @@ impl PyRange { #[pymethod(name = "__len__")] fn len(&self, _vm: &VirtualMachine) -> PyInt { - let start = self.start.as_bigint(); - let stop = self.stop.as_bigint(); - let step = self.step.as_bigint(); - - match step.sign() { - Sign::Plus if start < stop => PyInt::new((stop - start - 1usize) / step + 1), - Sign::Minus if start > stop => PyInt::new((start - stop - 1usize) / (-step) + 1), - Sign::Plus | Sign::Minus => PyInt::new(0), - Sign::NoSign => unreachable!(), - } + self.length() } #[pymethod(name = "__repr__")] @@ -243,9 +248,24 @@ impl PyRange { fn eq(&self, rhs: PyObjectRef, vm: &VirtualMachine) -> bool { if objtype::isinstance(&rhs, &vm.ctx.range_type()) { let rhs = get_value(&rhs); - self.start.as_bigint() == rhs.start.as_bigint() - && self.stop.as_bigint() == rhs.stop.as_bigint() - && self.step.as_bigint() == rhs.step.as_bigint() + + if self.length().as_bigint() != rhs.length().as_bigint() { + return false; + } + + if self.length().as_bigint().is_zero() { + return true; + } + + if self.start.as_bigint() != rhs.start.as_bigint() { + return false; + } + let step = self.step.as_bigint(); + if step.is_one() || step == rhs.step.as_bigint() { + return true; + } + + false } else { false }