optimize range.__contains__ and iter_search

This commit is contained in:
Jeong, YunWon
2024-05-09 01:01:07 +09:00
committed by Jeong, YunWon
parent fe06583643
commit 52ce1509a5

View File

@@ -32,15 +32,15 @@ enum SearchType {
// and place in vm.rs for all sequences to be able to use it.
#[inline]
fn iter_search(
obj: PyObjectRef,
item: PyObjectRef,
obj: &PyObject,
item: &PyObject,
flag: SearchType,
vm: &VirtualMachine,
) -> PyResult<usize> {
let mut count = 0;
let iter = obj.get_iter(vm)?;
for element in iter.iter_without_hint::<PyObjectRef>(vm)? {
if vm.bool_eq(&item, &*element?)? {
if vm.bool_eq(item, &*element?)? {
match flag {
SearchType::Index => return Ok(count),
SearchType::Contains => return Ok(1),
@@ -53,8 +53,7 @@ fn iter_search(
SearchType::Contains => Ok(0),
SearchType::Index => Err(vm.new_value_error(format!(
"{} not in range",
&item
.repr(vm)
item.repr(vm)
.map(|v| v.as_str().to_owned())
.unwrap_or_else(|_| "value".to_owned())
))),
@@ -174,7 +173,15 @@ pub fn init(context: &Context) {
PyRangeIterator::extend_class(context, context.types.range_iterator_type);
}
#[pyclass(with(AsMapping, AsSequence, Hashable, Comparable, Iterable, Representable))]
#[pyclass(with(
Py,
AsMapping,
AsSequence,
Hashable,
Comparable,
Iterable,
Representable
))]
impl PyRange {
fn new(cls: PyTypeRef, stop: ArgIndex, vm: &VirtualMachine) -> PyResult<PyRef<Self>> {
PyRange {
@@ -265,26 +272,6 @@ impl PyRange {
!self.is_empty()
}
#[pymethod(magic)]
fn contains(&self, needle: PyObjectRef, vm: &VirtualMachine) -> bool {
// Only accept ints, not subclasses.
if let Some(int) = needle.payload_if_exact::<PyInt>(vm) {
match self.offset(int.as_bigint()) {
Some(ref offset) => offset.is_multiple_of(self.step.as_bigint()),
None => false,
}
} else {
iter_search(
self.clone().into_pyobject(vm),
needle,
SearchType::Contains,
vm,
)
.unwrap_or(0)
!= 0
}
}
#[pymethod(magic)]
fn reduce(&self, vm: &VirtualMachine) -> (PyTypeRef, PyTupleRef) {
let range_parameters: Vec<PyObjectRef> = [&self.start, &self.stop, &self.step]
@@ -295,43 +282,6 @@ impl PyRange {
(vm.ctx.types.range_type.to_owned(), range_parameters_tuple)
}
#[pymethod]
fn index(&self, needle: PyObjectRef, vm: &VirtualMachine) -> PyResult<BigInt> {
if let Ok(int) = needle.clone().downcast::<PyInt>() {
match self.index_of(int.as_bigint()) {
Some(idx) => Ok(idx),
None => Err(vm.new_value_error(format!("{int} is not in range"))),
}
} else {
// Fallback to iteration.
Ok(BigInt::from_bytes_be(
Sign::Plus,
&iter_search(
self.clone().into_pyobject(vm),
needle,
SearchType::Index,
vm,
)?
.to_be_bytes(),
))
}
}
#[pymethod]
fn count(&self, item: PyObjectRef, vm: &VirtualMachine) -> PyResult<usize> {
if let Ok(int) = item.clone().downcast::<PyInt>() {
if self.index_of(int.as_bigint()).is_some() {
Ok(1)
} else {
Ok(0)
}
} else {
// Dealing with classes who might compare equal with ints in their
// __eq__, slow search.
iter_search(self.clone().into_pyobject(vm), item, SearchType::Count, vm)
}
}
#[pymethod(magic)]
fn getitem(&self, subscript: PyObjectRef, vm: &VirtualMachine) -> PyResult {
match RangeIndex::try_from_object(vm, subscript)? {
@@ -374,6 +324,58 @@ impl PyRange {
}
}
#[pyclass]
impl Py<PyRange> {
fn contains_inner(&self, needle: &PyObject, vm: &VirtualMachine) -> bool {
// Only accept ints, not subclasses.
if let Some(int) = needle.downcast_ref_if_exact::<PyInt>(vm) {
match self.offset(int.as_bigint()) {
Some(ref offset) => offset.is_multiple_of(self.step.as_bigint()),
None => false,
}
} else {
iter_search(self.as_object(), needle, SearchType::Contains, vm).unwrap_or(0) != 0
}
}
#[pymethod(magic)]
fn contains(&self, needle: PyObjectRef, vm: &VirtualMachine) -> bool {
self.contains_inner(&needle, vm)
}
#[pymethod]
fn index(&self, needle: PyObjectRef, vm: &VirtualMachine) -> PyResult<BigInt> {
if let Ok(int) = needle.clone().downcast::<PyInt>() {
match self.index_of(int.as_bigint()) {
Some(idx) => Ok(idx),
None => Err(vm.new_value_error(format!("{int} is not in range"))),
}
} else {
// Fallback to iteration.
Ok(BigInt::from_bytes_be(
Sign::Plus,
&iter_search(self.as_object(), &needle, SearchType::Index, vm)?.to_be_bytes(),
))
}
}
#[pymethod]
fn count(&self, item: PyObjectRef, vm: &VirtualMachine) -> PyResult<usize> {
if let Ok(int) = item.clone().downcast::<PyInt>() {
let count = if self.index_of(int.as_bigint()).is_some() {
1
} else {
0
};
Ok(count)
} else {
// Dealing with classes who might compare equal with ints in their
// __eq__, slow search.
iter_search(self.as_object(), &item, SearchType::Count, vm)
}
}
}
impl PyRange {
fn protocol_length(&self, vm: &VirtualMachine) -> PyResult<usize> {
PyInt::from(self.len())
@@ -408,7 +410,7 @@ impl AsSequence for PyRange {
.ok_or_else(|| vm.new_index_error("index out of range".to_owned()))
}),
contains: atomic_func!(|seq, needle, vm| {
Ok(PyRange::sequence_downcast(seq).contains(needle.to_owned(), vm))
Ok(PyRange::sequence_downcast(seq).contains_inner(needle, vm))
}),
..PySequenceMethods::NOT_IMPLEMENTED
});