diff --git a/stdlib/src/array.rs b/stdlib/src/array.rs index 3241e2444d..c002bed8e2 100644 --- a/stdlib/src/array.rs +++ b/stdlib/src/array.rs @@ -23,7 +23,7 @@ mod array { BufferDescriptor, BufferMethods, BufferResizeGuard, PyBuffer, PyIterReturn, PyMappingMethods, }, - sequence::{SequenceExt, SequenceMutExt}, + sequence::{OptionalRangeArgs, SequenceExt, SequenceMutExt}, sliceable::{ SaturatedSlice, SequenceIndex, SequenceIndexOp, SliceableSequenceMutOp, SliceableSequenceOp, @@ -867,17 +867,10 @@ mod array { fn index( &self, x: PyObjectRef, - start: OptionalArg, - stop: OptionalArg, + range: OptionalRangeArgs, vm: &VirtualMachine, ) -> PyResult { - let len = self.len(); - let saturate = |obj: PyObjectRef, len| -> PyResult<_> { - obj.try_into_value(vm) - .map(|int: PyIntRef| int.as_bigint().saturated_at(len)) - }; - let start = start.map_or(Ok(0), |obj| saturate(obj, len))?; - let stop = stop.map_or(Ok(len), |obj| saturate(obj, len))?; + let (start, stop) = range.saturate(self.len(), vm)?; self.read().index(x, start, stop, vm) } diff --git a/vm/src/builtins/list.rs b/vm/src/builtins/list.rs index 19162b2f3f..0ab5eae69c 100644 --- a/vm/src/builtins/list.rs +++ b/vm/src/builtins/list.rs @@ -1,4 +1,4 @@ -use super::{PositionIterInternal, PyGenericAlias, PyIntRef, PyTupleRef, PyType, PyTypeRef}; +use super::{PositionIterInternal, PyGenericAlias, PyTupleRef, PyType, PyTypeRef}; use crate::common::lock::{ PyMappedRwLockReadGuard, PyMutex, PyRwLock, PyRwLockReadGuard, PyRwLockWriteGuard, }; @@ -9,8 +9,8 @@ use crate::{ iter::PyExactSizeIterator, protocol::{PyIterReturn, PyMappingMethods, PySequence, PySequenceMethods}, recursion::ReprGuard, - sequence::{MutObjectSequenceOp, SequenceExt, SequenceMutExt}, - sliceable::{SequenceIndex, SequenceIndexOp, SliceableSequenceMutOp, SliceableSequenceOp}, + sequence::{MutObjectSequenceOp, OptionalRangeArgs, SequenceExt, SequenceMutExt}, + sliceable::{SequenceIndex, SliceableSequenceMutOp, SliceableSequenceOp}, types::{ AsMapping, AsSequence, Comparable, Constructor, Hashable, Initializer, IterNext, IterNextIterable, Iterable, PyComparisonOp, Unconstructible, Unhashable, @@ -269,17 +269,10 @@ impl PyList { fn index( &self, needle: PyObjectRef, - start: OptionalArg, - stop: OptionalArg, + range: OptionalRangeArgs, vm: &VirtualMachine, ) -> PyResult { - let len = self.len(); - let saturate = |obj: PyObjectRef, len| -> PyResult<_> { - obj.try_into_value(vm) - .map(|int: PyIntRef| int.as_bigint().saturated_at(len)) - }; - let start = start.map_or(Ok(0), |obj| saturate(obj, len))?; - let stop = stop.map_or(Ok(len), |obj| saturate(obj, len))?; + let (start, stop) = range.saturate(self.len(), vm)?; let index = self.mut_index_range(vm, &needle, start..stop)?; if let Some(index) = index.into() { Ok(index) diff --git a/vm/src/builtins/tuple.rs b/vm/src/builtins/tuple.rs index 8e3c11fef7..ce6d0cfbec 100644 --- a/vm/src/builtins/tuple.rs +++ b/vm/src/builtins/tuple.rs @@ -1,4 +1,4 @@ -use super::{PositionIterInternal, PyGenericAlias, PyIntRef, PyType, PyTypeRef}; +use super::{PositionIterInternal, PyGenericAlias, PyType, PyTypeRef}; use crate::common::{hash::PyHash, lock::PyMutex}; use crate::{ class::PyClassImpl, @@ -7,8 +7,7 @@ use crate::{ iter::PyExactSizeIterator, protocol::{PyIterReturn, PyMappingMethods, PySequenceMethods}, recursion::ReprGuard, - sequence::SequenceExt, - sliceable::SequenceIndexOp, + sequence::{OptionalRangeArgs, SequenceExt}, sliceable::{SequenceIndex, SliceableSequenceOp}, types::{ AsMapping, AsSequence, Comparable, Constructor, Hashable, IterNext, IterNextIterable, @@ -281,17 +280,10 @@ impl PyTuple { fn index( &self, needle: PyObjectRef, - start: OptionalArg, - stop: OptionalArg, + range: OptionalRangeArgs, vm: &VirtualMachine, ) -> PyResult { - let len = self.len(); - let saturate = |obj: PyObjectRef, len| -> PyResult<_> { - obj.try_into_value(vm) - .map(|int: PyIntRef| int.as_bigint().saturated_at(len)) - }; - let start = start.map_or(Ok(0), |i| saturate(i, len))?; - let stop = stop.map_or(Ok(len), |i| saturate(i, len))?; + let (start, stop) = range.saturate(self.len(), vm)?; for (index, element) in self .elements .iter() diff --git a/vm/src/sequence.rs b/vm/src/sequence.rs index b7eb01964a..dc7a3b86ee 100644 --- a/vm/src/sequence.rs +++ b/vm/src/sequence.rs @@ -1,5 +1,7 @@ use crate::{ - function::{Either, PyComparisonValue}, + builtins::PyIntRef, + function::{Either, OptionalArg, PyComparisonValue}, + sliceable::SequenceIndexOp, types::{richcompare_wrapper, PyComparisonOp, RichCompareFunc}, vm::VirtualMachine, AsObject, PyObject, PyObjectRef, PyResult, @@ -225,3 +227,23 @@ impl SequenceMutExt for Vec { self } } + +#[derive(FromArgs)] +pub struct OptionalRangeArgs { + #[pyarg(positional, optional)] + start: OptionalArg, + #[pyarg(positional, optional)] + stop: OptionalArg, +} + +impl OptionalRangeArgs { + pub fn saturate(self, len: usize, vm: &VirtualMachine) -> PyResult<(usize, usize)> { + let saturate = |obj: PyObjectRef| -> PyResult<_> { + obj.try_into_value(vm) + .map(|int: PyIntRef| int.as_bigint().saturated_at(len)) + }; + let start = self.start.map_or(Ok(0), saturate)?; + let stop = self.stop.map_or(Ok(len), saturate)?; + Ok((start, stop)) + } +} diff --git a/vm/src/stdlib/collections.rs b/vm/src/stdlib/collections.rs index 34cadd9b0c..a34c29afe4 100644 --- a/vm/src/stdlib/collections.rs +++ b/vm/src/stdlib/collections.rs @@ -5,14 +5,14 @@ mod _collections { use crate::{ builtins::{ IterStatus::{Active, Exhausted}, - PositionIterInternal, PyGenericAlias, PyInt, PyIntRef, PyTypeRef, + PositionIterInternal, PyGenericAlias, PyInt, PyTypeRef, }, common::lock::{PyMutex, PyRwLock, PyRwLockReadGuard, PyRwLockWriteGuard}, function::{FuncArgs, KwArgs, OptionalArg, PyComparisonValue}, iter::PyExactSizeIterator, protocol::{PyIterReturn, PySequenceMethods}, recursion::ReprGuard, - sequence::MutObjectSequenceOp, + sequence::{MutObjectSequenceOp, OptionalRangeArgs}, sliceable::SequenceIndexOp, types::{ AsSequence, Comparable, Constructor, Hashable, Initializer, IterNext, IterNextIterable, @@ -157,19 +157,12 @@ mod _collections { fn index( &self, needle: PyObjectRef, - start: OptionalArg, - stop: OptionalArg, + range: OptionalRangeArgs, vm: &VirtualMachine, ) -> PyResult { let start_state = self.state.load(); - let len = self.len(); - let saturate = |obj: PyObjectRef, len| -> PyResult<_> { - obj.try_into_value(vm) - .map(|int: PyIntRef| int.as_bigint().saturated_at(len)) - }; - let start = start.map_or(Ok(0), |i| saturate(i, len))?; - let stop = stop.map_or(Ok(len), |i| saturate(i, len))?; + let (start, stop) = range.saturate(self.len(), vm)?; let index = self.mut_index_range(vm, &needle, start..stop)?; if start_state != self.state.load() { Err(vm.new_runtime_error("deque mutated during iteration".to_owned()))