mirror of
https://github.com/RustPython/RustPython.git
synced 2026-06-02 19:39:49 +09:00
Merge pull request #1666 from RustPython/coolreader18/length_hint
Add __length_hint__ support for iterators
This commit is contained in:
@@ -254,3 +254,16 @@ assert c == {'bla', 'c', 'd', 'f'}
|
||||
assert not {}.__ne__({})
|
||||
assert {}.__ne__({'a':'b'})
|
||||
assert {}.__ne__(1) == NotImplemented
|
||||
|
||||
it = iter({0: 1, 2: 3, 4:5, 6:7})
|
||||
assert it.__length_hint__() == 4
|
||||
next(it)
|
||||
assert it.__length_hint__() == 3
|
||||
next(it)
|
||||
assert it.__length_hint__() == 2
|
||||
next(it)
|
||||
assert it.__length_hint__() == 1
|
||||
next(it)
|
||||
assert it.__length_hint__() == 0
|
||||
assert_raises(StopIteration, next, it)
|
||||
assert it.__length_hint__() == 0
|
||||
|
||||
@@ -22,6 +22,7 @@ use crate::obj::objdict::PyDictRef;
|
||||
use crate::obj::objfunction::PyFunctionRef;
|
||||
use crate::obj::objint::{self, PyIntRef};
|
||||
use crate::obj::objiter;
|
||||
use crate::obj::objsequence;
|
||||
use crate::obj::objstr::{PyString, PyStringRef};
|
||||
use crate::obj::objtype::{self, PyClassRef};
|
||||
use crate::pyhash;
|
||||
@@ -384,11 +385,8 @@ fn builtin_iter(iter_target: PyObjectRef, vm: &VirtualMachine) -> PyResult {
|
||||
objiter::get_iter(vm, &iter_target)
|
||||
}
|
||||
|
||||
fn builtin_len(obj: PyObjectRef, vm: &VirtualMachine) -> PyResult {
|
||||
let method = vm.get_method_or_type_error(obj.clone(), "__len__", || {
|
||||
format!("object of type '{}' has no len()", obj.class().name)
|
||||
})?;
|
||||
vm.invoke(&method, PyFuncArgs::default())
|
||||
fn builtin_len(obj: PyObjectRef, vm: &VirtualMachine) -> PyResult<usize> {
|
||||
objsequence::len(&obj, vm)
|
||||
}
|
||||
|
||||
fn builtin_locals(vm: &VirtualMachine) -> PyDictRef {
|
||||
|
||||
@@ -204,26 +204,26 @@ impl<T: Clone> Dict<T> {
|
||||
}
|
||||
|
||||
pub fn next_entry(&self, position: &mut EntryIndex) -> Option<(&PyObjectRef, &T)> {
|
||||
while *position < self.entries.len() {
|
||||
if let Some(DictEntry { key, value, .. }) = &self.entries[*position] {
|
||||
*position += 1;
|
||||
return Some((key, value));
|
||||
}
|
||||
self.entries[*position..].iter().find_map(|entry| {
|
||||
*position += 1;
|
||||
}
|
||||
None
|
||||
entry
|
||||
.as_ref()
|
||||
.map(|DictEntry { key, value, .. }| (key, value))
|
||||
})
|
||||
}
|
||||
|
||||
pub fn len_from_entry_index(&self, position: EntryIndex) -> usize {
|
||||
self.entries[position..].iter().flatten().count()
|
||||
}
|
||||
|
||||
pub fn has_changed_size(&self, position: &DictSize) -> bool {
|
||||
position.size != self.size || self.entries.len() != position.entries_size
|
||||
}
|
||||
|
||||
pub fn keys<'a>(&'a self) -> Box<dyn Iterator<Item = PyObjectRef> + 'a> {
|
||||
Box::new(
|
||||
self.entries
|
||||
.iter()
|
||||
.filter_map(|v| v.as_ref().map(|v| v.key.clone())),
|
||||
)
|
||||
pub fn keys<'a>(&'a self) -> impl Iterator<Item = PyObjectRef> + 'a {
|
||||
self.entries
|
||||
.iter()
|
||||
.filter_map(|v| v.as_ref().map(|v| v.key.clone()))
|
||||
}
|
||||
|
||||
/// Lookup the index for the given key.
|
||||
|
||||
@@ -250,8 +250,6 @@ pub fn write_exception_inner<W: Write>(
|
||||
write_traceback_entry(output, traceback)?;
|
||||
tb = traceback.next.as_ref();
|
||||
}
|
||||
} else {
|
||||
writeln!(output, "No traceback set on exception")?;
|
||||
}
|
||||
|
||||
let varargs = exc.args();
|
||||
|
||||
@@ -545,6 +545,14 @@ macro_rules! dict_iterator {
|
||||
fn iter(zelf: PyRef<Self>, _vm: &VirtualMachine) -> PyRef<Self> {
|
||||
zelf
|
||||
}
|
||||
|
||||
#[pymethod(name = "__length_hint__")]
|
||||
fn length_hint(&self, _vm: &VirtualMachine) -> usize {
|
||||
self.dict
|
||||
.entries
|
||||
.borrow()
|
||||
.len_from_entry_index(self.position.get())
|
||||
}
|
||||
}
|
||||
|
||||
impl PyValue for $iter_name {
|
||||
|
||||
@@ -2,8 +2,11 @@
|
||||
* Various types to support iteration.
|
||||
*/
|
||||
|
||||
use num_traits::{Signed, ToPrimitive};
|
||||
use std::cell::Cell;
|
||||
|
||||
use super::objint::PyInt;
|
||||
use super::objsequence;
|
||||
use super::objtype::{self, PyClassRef};
|
||||
use crate::exceptions::PyBaseExceptionRef;
|
||||
use crate::pyobject::{
|
||||
@@ -61,10 +64,12 @@ pub fn get_next_object(
|
||||
|
||||
/* Retrieve all elements from an iterator */
|
||||
pub fn get_all<T: TryFromObject>(vm: &VirtualMachine, iter_obj: &PyObjectRef) -> PyResult<Vec<T>> {
|
||||
let mut elements = vec![];
|
||||
let cap = length_hint(vm, iter_obj.clone())?.unwrap_or(0);
|
||||
let mut elements = Vec::with_capacity(cap);
|
||||
while let Some(element) = get_next_object(vm, iter_obj)? {
|
||||
elements.push(T::try_from_object(vm, element)?);
|
||||
}
|
||||
elements.shrink_to_fit();
|
||||
Ok(elements)
|
||||
}
|
||||
|
||||
@@ -83,6 +88,49 @@ pub fn stop_iter_value(vm: &VirtualMachine, exc: &PyBaseExceptionRef) -> PyResul
|
||||
Ok(val)
|
||||
}
|
||||
|
||||
pub fn length_hint(vm: &VirtualMachine, iter: PyObjectRef) -> PyResult<Option<usize>> {
|
||||
if let Some(len) = objsequence::opt_len(&iter, vm) {
|
||||
match len {
|
||||
Ok(len) => return Ok(Some(len)),
|
||||
Err(e) => {
|
||||
if !objtype::isinstance(&e, &vm.ctx.exceptions.type_error) {
|
||||
return Err(e);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
let hint = match vm.get_method(iter, "__length_hint__") {
|
||||
Some(hint) => hint?,
|
||||
None => return Ok(None),
|
||||
};
|
||||
let result = match vm.invoke(&hint, vec![]) {
|
||||
Ok(res) => res,
|
||||
Err(e) => {
|
||||
if objtype::isinstance(&e, &vm.ctx.exceptions.type_error) {
|
||||
return Ok(None);
|
||||
} else {
|
||||
return Err(e);
|
||||
}
|
||||
}
|
||||
};
|
||||
let result = result
|
||||
.payload_if_subclass::<PyInt>(vm)
|
||||
.ok_or_else(|| {
|
||||
vm.new_type_error(format!(
|
||||
"'{}' object cannot be interpreted as an integer",
|
||||
result.class().name
|
||||
))
|
||||
})?
|
||||
.as_bigint();
|
||||
if result.is_negative() {
|
||||
return Err(vm.new_value_error("__length_hint__() should return >= 0".to_string()));
|
||||
}
|
||||
let hint = result.to_usize().ok_or_else(|| {
|
||||
vm.new_value_error("Python int too large to convert to Rust usize".to_string())
|
||||
})?;
|
||||
Ok(Some(hint))
|
||||
}
|
||||
|
||||
#[pyclass]
|
||||
#[derive(Debug)]
|
||||
pub struct PySequenceIterator {
|
||||
@@ -124,6 +172,20 @@ impl PySequenceIterator {
|
||||
fn iter(zelf: PyRef<Self>, _vm: &VirtualMachine) -> PyRef<Self> {
|
||||
zelf
|
||||
}
|
||||
|
||||
#[pymethod(name = "__length_hint__")]
|
||||
fn length_hint(&self, vm: &VirtualMachine) -> PyResult<isize> {
|
||||
let pos = self.position.get();
|
||||
let hint = if self.reversed {
|
||||
pos + 1
|
||||
} else {
|
||||
let len = objsequence::opt_len(&self.obj, vm).unwrap_or_else(|| {
|
||||
Err(vm.new_type_error("sequence has no __len__ method".to_string()))
|
||||
})?;
|
||||
len as isize - pos
|
||||
};
|
||||
Ok(hint)
|
||||
}
|
||||
}
|
||||
|
||||
pub fn seq_iter_method(obj: PyObjectRef, _vm: &VirtualMachine) -> PySequenceIterator {
|
||||
|
||||
@@ -873,6 +873,11 @@ impl PyListIterator {
|
||||
fn iter(zelf: PyRef<Self>, _vm: &VirtualMachine) -> PyRef<Self> {
|
||||
zelf
|
||||
}
|
||||
|
||||
#[pymethod(name = "__length_hint__")]
|
||||
fn length_hint(&self, _vm: &VirtualMachine) -> usize {
|
||||
self.list.elements.borrow().len() - self.position.get()
|
||||
}
|
||||
}
|
||||
|
||||
#[pyclass]
|
||||
@@ -906,6 +911,11 @@ impl PyListReverseIterator {
|
||||
fn iter(zelf: PyRef<Self>, _vm: &VirtualMachine) -> PyRef<Self> {
|
||||
zelf
|
||||
}
|
||||
|
||||
#[pymethod(name = "__length_hint__")]
|
||||
fn length_hint(&self, _vm: &VirtualMachine) -> usize {
|
||||
self.position.get()
|
||||
}
|
||||
}
|
||||
|
||||
pub fn init(context: &PyContext) {
|
||||
|
||||
@@ -58,6 +58,15 @@ impl PyMap {
|
||||
fn iter(zelf: PyRef<Self>, _vm: &VirtualMachine) -> PyRef<Self> {
|
||||
zelf
|
||||
}
|
||||
|
||||
#[pymethod(name = "__length_hint__")]
|
||||
fn length_hint(&self, vm: &VirtualMachine) -> PyResult<usize> {
|
||||
self.iterators.iter().try_fold(0, |prev, cur| {
|
||||
let cur = objiter::length_hint(vm, cur.clone())?.unwrap_or(0);
|
||||
let max = std::cmp::max(prev, cur);
|
||||
Ok(max)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
pub fn init(context: &PyContext) {
|
||||
|
||||
@@ -210,8 +210,8 @@ impl PyRange {
|
||||
}
|
||||
|
||||
#[pymethod(name = "__len__")]
|
||||
fn len(&self, _vm: &VirtualMachine) -> PyInt {
|
||||
PyInt::new(self.length())
|
||||
fn len(&self, _vm: &VirtualMachine) -> BigInt {
|
||||
self.length()
|
||||
}
|
||||
|
||||
#[pymethod(name = "__repr__")]
|
||||
|
||||
@@ -263,3 +263,33 @@ pub fn is_valid_slice_arg(
|
||||
Ok(None)
|
||||
}
|
||||
}
|
||||
|
||||
pub fn opt_len(obj: &PyObjectRef, vm: &VirtualMachine) -> Option<PyResult<usize>> {
|
||||
vm.get_method(obj.clone(), "__len__").map(|len| {
|
||||
let len = vm.invoke(&len?, vec![])?;
|
||||
let len = len
|
||||
.payload_if_subclass::<PyInt>(vm)
|
||||
.ok_or_else(|| {
|
||||
vm.new_type_error(format!(
|
||||
"'{}' object cannot be interpreted as an integer",
|
||||
len.class().name
|
||||
))
|
||||
})?
|
||||
.as_bigint();
|
||||
if len.is_negative() {
|
||||
return Err(vm.new_value_error("__len__() should return >= 0".to_string()));
|
||||
}
|
||||
len.to_usize().ok_or_else(|| {
|
||||
vm.new_overflow_error("cannot fit __len__() result into usize".to_string())
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
pub fn len(obj: &PyObjectRef, vm: &VirtualMachine) -> PyResult<usize> {
|
||||
opt_len(obj, vm).unwrap_or_else(|| {
|
||||
Err(vm.new_type_error(format!(
|
||||
"object of type '{}' has no len()",
|
||||
obj.class().name
|
||||
)))
|
||||
})
|
||||
}
|
||||
|
||||
@@ -1,12 +1,9 @@
|
||||
use std::cell::{Cell, RefCell};
|
||||
use std::cmp::Ordering;
|
||||
use std::iter;
|
||||
use std::ops::{AddAssign, SubAssign};
|
||||
use std::rc::Rc;
|
||||
|
||||
use num_bigint::BigInt;
|
||||
use num_traits::sign::Signed;
|
||||
use num_traits::ToPrimitive;
|
||||
use num_traits::{One, Signed, ToPrimitive, Zero};
|
||||
|
||||
use crate::function::{Args, OptionalArg, OptionalOption, PyFuncArgs};
|
||||
use crate::obj::objbool;
|
||||
@@ -166,11 +163,11 @@ impl PyItertoolsCount {
|
||||
) -> PyResult<PyRef<Self>> {
|
||||
let start = match start.into_option() {
|
||||
Some(int) => int.as_bigint().clone(),
|
||||
None => BigInt::from(0),
|
||||
None => BigInt::zero(),
|
||||
};
|
||||
let step = match step.into_option() {
|
||||
Some(int) => int.as_bigint().clone(),
|
||||
None => BigInt::from(1),
|
||||
None => BigInt::one(),
|
||||
};
|
||||
|
||||
PyItertoolsCount {
|
||||
@@ -183,7 +180,7 @@ impl PyItertoolsCount {
|
||||
#[pymethod(name = "__next__")]
|
||||
fn next(&self, _vm: &VirtualMachine) -> PyResult<PyInt> {
|
||||
let result = self.cur.borrow().clone();
|
||||
AddAssign::add_assign(&mut self.cur.borrow_mut() as &mut BigInt, &self.step);
|
||||
*self.cur.borrow_mut() += &self.step;
|
||||
Ok(PyInt::new(result))
|
||||
}
|
||||
|
||||
@@ -296,16 +293,11 @@ impl PyItertoolsRepeat {
|
||||
|
||||
#[pymethod(name = "__next__")]
|
||||
fn next(&self, vm: &VirtualMachine) -> PyResult {
|
||||
if self.times.is_some() {
|
||||
match self.times.as_ref().unwrap().borrow().cmp(&BigInt::from(0)) {
|
||||
Ordering::Less | Ordering::Equal => return Err(new_stop_iteration(vm)),
|
||||
_ => (),
|
||||
};
|
||||
|
||||
SubAssign::sub_assign(
|
||||
&mut self.times.as_ref().unwrap().borrow_mut() as &mut BigInt,
|
||||
&BigInt::from(1),
|
||||
);
|
||||
if let Some(ref times) = self.times {
|
||||
if *times.borrow() <= BigInt::zero() {
|
||||
return Err(new_stop_iteration(vm));
|
||||
}
|
||||
*times.borrow_mut() -= 1;
|
||||
}
|
||||
|
||||
Ok(self.object.clone())
|
||||
@@ -315,6 +307,14 @@ impl PyItertoolsRepeat {
|
||||
fn iter(zelf: PyRef<Self>, _vm: &VirtualMachine) -> PyRef<Self> {
|
||||
zelf
|
||||
}
|
||||
|
||||
#[pymethod(name = "__length_hint__")]
|
||||
fn length_hint(&self, vm: &VirtualMachine) -> PyObjectRef {
|
||||
match self.times {
|
||||
Some(ref times) => vm.new_int(times.borrow().clone()),
|
||||
None => vm.new_int(0),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[pyclass(name = "starmap")]
|
||||
|
||||
@@ -16,6 +16,7 @@ mod json;
|
||||
mod keyword;
|
||||
mod marshal;
|
||||
mod math;
|
||||
mod operator;
|
||||
mod platform;
|
||||
mod pystruct;
|
||||
mod random;
|
||||
@@ -74,6 +75,7 @@ pub fn get_module_inits() -> HashMap<String, StdlibInitFunc> {
|
||||
"json".to_string() => Box::new(json::make_module),
|
||||
"marshal".to_string() => Box::new(marshal::make_module),
|
||||
"math".to_string() => Box::new(math::make_module),
|
||||
"_operator".to_string() => Box::new(operator::make_module),
|
||||
"platform".to_string() => Box::new(platform::make_module),
|
||||
"regex_crate".to_string() => Box::new(re::make_module),
|
||||
"random".to_string() => Box::new(random::make_module),
|
||||
|
||||
24
vm/src/stdlib/operator.rs
Normal file
24
vm/src/stdlib/operator.rs
Normal file
@@ -0,0 +1,24 @@
|
||||
use crate::function::OptionalArg;
|
||||
use crate::obj::{objiter, objtype};
|
||||
use crate::pyobject::{PyObjectRef, PyResult, TypeProtocol};
|
||||
use crate::VirtualMachine;
|
||||
|
||||
fn operator_length_hint(obj: PyObjectRef, default: OptionalArg, vm: &VirtualMachine) -> PyResult {
|
||||
let default = default.unwrap_or_else(|| vm.new_int(0));
|
||||
if !objtype::isinstance(&default, &vm.ctx.types.int_type) {
|
||||
return Err(vm.new_type_error(format!(
|
||||
"'{}' type cannot be interpreted as an integer",
|
||||
default.class().name
|
||||
)));
|
||||
}
|
||||
let hint = objiter::length_hint(vm, obj)?
|
||||
.map(|i| vm.new_int(i))
|
||||
.unwrap_or(default);
|
||||
Ok(hint)
|
||||
}
|
||||
|
||||
pub fn make_module(vm: &VirtualMachine) -> PyObjectRef {
|
||||
py_module!(vm, "_operator", {
|
||||
"length_hint" => vm.ctx.new_rustfunc(operator_length_hint),
|
||||
})
|
||||
}
|
||||
Reference in New Issue
Block a user