From 67bc687b31f07d6d7dc805e4b626e5b0c0f6e336 Mon Sep 17 00:00:00 2001 From: HyeockJinKim Date: Sat, 5 Oct 2019 17:24:09 +0900 Subject: [PATCH 1/2] Add comparison of slice Add gt, ge, lt, le for slice Issue #1431 --- vm/src/obj/objslice.rs | 82 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 82 insertions(+) diff --git a/vm/src/obj/objslice.rs b/vm/src/obj/objslice.rs index 32542c5fd..e0bd27e72 100644 --- a/vm/src/obj/objslice.rs +++ b/vm/src/obj/objslice.rs @@ -129,6 +129,48 @@ impl PySlice { Ok(true) } + fn inner_lte(&self, other: &PySlice, eq: bool, vm: &VirtualMachine) -> PyResult { + if vm.bool_lt(self.start(vm), other.start(vm))? { + return Ok(true); + } else if !vm.bool_eq(self.start(vm), other.start(vm))? { + return Ok(false); + } + + if vm.bool_lt(self.stop(vm), other.stop(vm))? { + return Ok(true); + } else if !vm.bool_eq(self.stop(vm), other.stop(vm))? { + return Ok(false); + } + + if vm.bool_lt(self.step(vm), other.step(vm))? { + return Ok(true); + } else if !vm.bool_eq(self.step(vm), other.step(vm))? { + return Ok(false); + } + Ok(eq) + } + + fn inner_gte(&self, other: &PySlice, eq: bool, vm: &VirtualMachine) -> PyResult { + if vm.bool_gt(self.start(vm), other.start(vm))? { + return Ok(true); + } else if !vm.bool_eq(self.start(vm), other.start(vm))? { + return Ok(false); + } + + if vm.bool_gt(self.stop(vm), other.stop(vm))? { + return Ok(true); + } else if !vm.bool_eq(self.stop(vm), other.stop(vm))? { + return Ok(false); + } + + if vm.bool_gt(self.step(vm), other.step(vm))? { + return Ok(true); + } else if !vm.bool_eq(self.step(vm), other.step(vm))? { + return Ok(false); + } + Ok(eq) + } + #[pymethod(name = "__eq__")] fn eq(&self, rhs: PyObjectRef, vm: &VirtualMachine) -> PyResult { if let Some(rhs) = rhs.payload::() { @@ -148,6 +190,46 @@ impl PySlice { Ok(vm.ctx.not_implemented()) } } + + #[pymethod(name = "__lt__")] + fn lt(&self, rhs: PyObjectRef, vm: &VirtualMachine) -> PyResult { + if let Some(rhs) = rhs.payload::() { + let lt = self.inner_lte(rhs, false, vm)?; + Ok(vm.ctx.new_bool(lt)) + } else { + Ok(vm.ctx.not_implemented()) + } + } + + #[pymethod(name = "__gt__")] + fn gt(&self, rhs: PyObjectRef, vm: &VirtualMachine) -> PyResult { + if let Some(rhs) = rhs.payload::() { + let gt = self.inner_gte(rhs, false, vm)?; + Ok(vm.ctx.new_bool(gt)) + } else { + Ok(vm.ctx.not_implemented()) + } + } + + #[pymethod(name = "__ge__")] + fn ge(&self, rhs: PyObjectRef, vm: &VirtualMachine) -> PyResult { + if let Some(rhs) = rhs.payload::() { + let ge = self.inner_gte(rhs, true, vm)?; + Ok(vm.ctx.new_bool(ge)) + } else { + Ok(vm.ctx.not_implemented()) + } + } + + #[pymethod(name = "__le__")] + fn le(&self, rhs: PyObjectRef, vm: &VirtualMachine) -> PyResult { + if let Some(rhs) = rhs.payload::() { + let le = self.inner_lte(rhs, true, vm)?; + Ok(vm.ctx.new_bool(le)) + } else { + Ok(vm.ctx.not_implemented()) + } + } } fn to_index_value(vm: &VirtualMachine, obj: &PyObjectRef) -> PyResult> { From a6a996d1caf6f3a7cfe2e1b7714697bf0997b6d3 Mon Sep 17 00:00:00 2001 From: HyeockJinKim Date: Sat, 5 Oct 2019 17:30:34 +0900 Subject: [PATCH 2/2] Add tests for comparison of slice --- tests/snippets/slice.py | 52 ++++++++++++++++++++++++++++++++++++++++- 1 file changed, 51 insertions(+), 1 deletion(-) diff --git a/tests/snippets/slice.py b/tests/snippets/slice.py index a76a8f733..53bd226c8 100644 --- a/tests/snippets/slice.py +++ b/tests/snippets/slice.py @@ -1,3 +1,4 @@ +from testutils import assert_raises a = slice(10) assert a.start == None @@ -30,4 +31,53 @@ assert slice(0, 0, 1).__ne__(slice(0, 0, 0)) assert slice(0).__eq__(0) == NotImplemented assert slice(0).__ne__(0) == NotImplemented -assert slice(None).__ne__(slice(0)) \ No newline at end of file +assert slice(None).__ne__(slice(0)) + +# slice gt, ge, lt, le +assert_raises(TypeError, lambda: slice(0, slice(), 0) < slice(0, 0, 0)) +assert_raises(TypeError, lambda: slice(0, slice(), 0) <= slice(0, 0, 0)) +assert_raises(TypeError, lambda: slice(0, slice(), 0) > slice(0, 0, 0)) +assert_raises(TypeError, lambda: slice(0, slice(), 0) >= slice(0, 0, 0)) + +assert_raises(TypeError, lambda: slice(0, 0, 0) < slice(0, 0, slice())) +assert_raises(TypeError, lambda: slice(0, 0, 0) <= slice(0, 0, slice())) +assert_raises(TypeError, lambda: slice(0, 0, 0) > slice(0, 0, slice())) +assert_raises(TypeError, lambda: slice(0, 0, 0) >= slice(0, 0, slice())) + +assert_raises(TypeError, lambda: slice(0, 0) >= slice(0, 0, 0)) +assert_raises(TypeError, lambda: slice(0, 0) <= slice(0, 0, 0)) +assert_raises(TypeError, lambda: slice(0, 0) < slice(0, 0, 0)) +assert_raises(TypeError, lambda: slice(0, 0) > slice(0, 0, 0)) + +assert slice(0, 0, 0) < slice(0, 1, -1) +assert slice(0, 0, 0) < slice(0, 0, 1) +assert slice(0, 0, 0) > slice(0, 0, -1) +assert slice(0, 0, 0) >= slice(0, 0, -1) +assert not slice(0, 0, 0) <= slice(0, 0, -1) + +assert slice(0, 0, 0) > slice(0, -1, 1) +assert slice(0, 0, 0) >= slice(0, -1, 1) +assert slice(0, 0, 0) >= slice(0, -1, 1) + +assert slice(0, 0, 0) <= slice(0, 0, 1) +assert slice(0, 0, 0) <= slice(0, 0, 0) +assert slice(0, 0, 0) <= slice(0, 0, 0) +assert not slice(0, 0, 0) > slice(0, 0, 0) +assert not slice(0, 0, 0) < slice(0, 0, 0) + +assert not slice(0, float('nan'), float('nan')) <= slice(0, float('nan'), 1) +assert not slice(0, float('nan'), float('nan')) <= slice(0, float('nan'), float('nan')) +assert not slice(0, float('nan'), float('nan')) >= slice(0, float('nan'), float('nan')) +assert not slice(0, float('nan'), float('nan')) < slice(0, float('nan'), float('nan')) +assert not slice(0, float('nan'), float('nan')) > slice(0, float('nan'), float('nan')) + +assert slice(0, float('inf'), float('inf')) >= slice(0, float('inf'), 1) +assert slice(0, float('inf'), float('inf')) <= slice(0, float('inf'), float('inf')) +assert slice(0, float('inf'), float('inf')) >= slice(0, float('inf'), float('inf')) +assert not slice(0, float('inf'), float('inf')) < slice(0, float('inf'), float('inf')) +assert not slice(0, float('inf'), float('inf')) > slice(0, float('inf'), float('inf')) + +assert_raises(TypeError, lambda: slice(0) < 3) +assert_raises(TypeError, lambda: slice(0) > 3) +assert_raises(TypeError, lambda: slice(0) <= 3) +assert_raises(TypeError, lambda: slice(0) >= 3)