Merge pull request #1666 from RustPython/coolreader18/length_hint

Add __length_hint__ support for iterators
This commit is contained in:
Noah
2020-01-10 12:25:36 -06:00
committed by GitHub
13 changed files with 194 additions and 40 deletions

View File

@@ -254,3 +254,16 @@ assert c == {'bla', 'c', 'd', 'f'}
assert not {}.__ne__({})
assert {}.__ne__({'a':'b'})
assert {}.__ne__(1) == NotImplemented
it = iter({0: 1, 2: 3, 4:5, 6:7})
assert it.__length_hint__() == 4
next(it)
assert it.__length_hint__() == 3
next(it)
assert it.__length_hint__() == 2
next(it)
assert it.__length_hint__() == 1
next(it)
assert it.__length_hint__() == 0
assert_raises(StopIteration, next, it)
assert it.__length_hint__() == 0

View File

@@ -22,6 +22,7 @@ use crate::obj::objdict::PyDictRef;
use crate::obj::objfunction::PyFunctionRef;
use crate::obj::objint::{self, PyIntRef};
use crate::obj::objiter;
use crate::obj::objsequence;
use crate::obj::objstr::{PyString, PyStringRef};
use crate::obj::objtype::{self, PyClassRef};
use crate::pyhash;
@@ -384,11 +385,8 @@ fn builtin_iter(iter_target: PyObjectRef, vm: &VirtualMachine) -> PyResult {
objiter::get_iter(vm, &iter_target)
}
fn builtin_len(obj: PyObjectRef, vm: &VirtualMachine) -> PyResult {
let method = vm.get_method_or_type_error(obj.clone(), "__len__", || {
format!("object of type '{}' has no len()", obj.class().name)
})?;
vm.invoke(&method, PyFuncArgs::default())
fn builtin_len(obj: PyObjectRef, vm: &VirtualMachine) -> PyResult<usize> {
objsequence::len(&obj, vm)
}
fn builtin_locals(vm: &VirtualMachine) -> PyDictRef {

View File

@@ -204,26 +204,26 @@ impl<T: Clone> Dict<T> {
}
pub fn next_entry(&self, position: &mut EntryIndex) -> Option<(&PyObjectRef, &T)> {
while *position < self.entries.len() {
if let Some(DictEntry { key, value, .. }) = &self.entries[*position] {
*position += 1;
return Some((key, value));
}
self.entries[*position..].iter().find_map(|entry| {
*position += 1;
}
None
entry
.as_ref()
.map(|DictEntry { key, value, .. }| (key, value))
})
}
pub fn len_from_entry_index(&self, position: EntryIndex) -> usize {
self.entries[position..].iter().flatten().count()
}
pub fn has_changed_size(&self, position: &DictSize) -> bool {
position.size != self.size || self.entries.len() != position.entries_size
}
pub fn keys<'a>(&'a self) -> Box<dyn Iterator<Item = PyObjectRef> + 'a> {
Box::new(
self.entries
.iter()
.filter_map(|v| v.as_ref().map(|v| v.key.clone())),
)
pub fn keys<'a>(&'a self) -> impl Iterator<Item = PyObjectRef> + 'a {
self.entries
.iter()
.filter_map(|v| v.as_ref().map(|v| v.key.clone()))
}
/// Lookup the index for the given key.

View File

@@ -250,8 +250,6 @@ pub fn write_exception_inner<W: Write>(
write_traceback_entry(output, traceback)?;
tb = traceback.next.as_ref();
}
} else {
writeln!(output, "No traceback set on exception")?;
}
let varargs = exc.args();

View File

@@ -545,6 +545,14 @@ macro_rules! dict_iterator {
fn iter(zelf: PyRef<Self>, _vm: &VirtualMachine) -> PyRef<Self> {
zelf
}
#[pymethod(name = "__length_hint__")]
fn length_hint(&self, _vm: &VirtualMachine) -> usize {
self.dict
.entries
.borrow()
.len_from_entry_index(self.position.get())
}
}
impl PyValue for $iter_name {

View File

@@ -2,8 +2,11 @@
* Various types to support iteration.
*/
use num_traits::{Signed, ToPrimitive};
use std::cell::Cell;
use super::objint::PyInt;
use super::objsequence;
use super::objtype::{self, PyClassRef};
use crate::exceptions::PyBaseExceptionRef;
use crate::pyobject::{
@@ -61,10 +64,12 @@ pub fn get_next_object(
/* Retrieve all elements from an iterator */
pub fn get_all<T: TryFromObject>(vm: &VirtualMachine, iter_obj: &PyObjectRef) -> PyResult<Vec<T>> {
let mut elements = vec![];
let cap = length_hint(vm, iter_obj.clone())?.unwrap_or(0);
let mut elements = Vec::with_capacity(cap);
while let Some(element) = get_next_object(vm, iter_obj)? {
elements.push(T::try_from_object(vm, element)?);
}
elements.shrink_to_fit();
Ok(elements)
}
@@ -83,6 +88,49 @@ pub fn stop_iter_value(vm: &VirtualMachine, exc: &PyBaseExceptionRef) -> PyResul
Ok(val)
}
pub fn length_hint(vm: &VirtualMachine, iter: PyObjectRef) -> PyResult<Option<usize>> {
if let Some(len) = objsequence::opt_len(&iter, vm) {
match len {
Ok(len) => return Ok(Some(len)),
Err(e) => {
if !objtype::isinstance(&e, &vm.ctx.exceptions.type_error) {
return Err(e);
}
}
}
}
let hint = match vm.get_method(iter, "__length_hint__") {
Some(hint) => hint?,
None => return Ok(None),
};
let result = match vm.invoke(&hint, vec![]) {
Ok(res) => res,
Err(e) => {
if objtype::isinstance(&e, &vm.ctx.exceptions.type_error) {
return Ok(None);
} else {
return Err(e);
}
}
};
let result = result
.payload_if_subclass::<PyInt>(vm)
.ok_or_else(|| {
vm.new_type_error(format!(
"'{}' object cannot be interpreted as an integer",
result.class().name
))
})?
.as_bigint();
if result.is_negative() {
return Err(vm.new_value_error("__length_hint__() should return >= 0".to_string()));
}
let hint = result.to_usize().ok_or_else(|| {
vm.new_value_error("Python int too large to convert to Rust usize".to_string())
})?;
Ok(Some(hint))
}
#[pyclass]
#[derive(Debug)]
pub struct PySequenceIterator {
@@ -124,6 +172,20 @@ impl PySequenceIterator {
fn iter(zelf: PyRef<Self>, _vm: &VirtualMachine) -> PyRef<Self> {
zelf
}
#[pymethod(name = "__length_hint__")]
fn length_hint(&self, vm: &VirtualMachine) -> PyResult<isize> {
let pos = self.position.get();
let hint = if self.reversed {
pos + 1
} else {
let len = objsequence::opt_len(&self.obj, vm).unwrap_or_else(|| {
Err(vm.new_type_error("sequence has no __len__ method".to_string()))
})?;
len as isize - pos
};
Ok(hint)
}
}
pub fn seq_iter_method(obj: PyObjectRef, _vm: &VirtualMachine) -> PySequenceIterator {

View File

@@ -873,6 +873,11 @@ impl PyListIterator {
fn iter(zelf: PyRef<Self>, _vm: &VirtualMachine) -> PyRef<Self> {
zelf
}
#[pymethod(name = "__length_hint__")]
fn length_hint(&self, _vm: &VirtualMachine) -> usize {
self.list.elements.borrow().len() - self.position.get()
}
}
#[pyclass]
@@ -906,6 +911,11 @@ impl PyListReverseIterator {
fn iter(zelf: PyRef<Self>, _vm: &VirtualMachine) -> PyRef<Self> {
zelf
}
#[pymethod(name = "__length_hint__")]
fn length_hint(&self, _vm: &VirtualMachine) -> usize {
self.position.get()
}
}
pub fn init(context: &PyContext) {

View File

@@ -58,6 +58,15 @@ impl PyMap {
fn iter(zelf: PyRef<Self>, _vm: &VirtualMachine) -> PyRef<Self> {
zelf
}
#[pymethod(name = "__length_hint__")]
fn length_hint(&self, vm: &VirtualMachine) -> PyResult<usize> {
self.iterators.iter().try_fold(0, |prev, cur| {
let cur = objiter::length_hint(vm, cur.clone())?.unwrap_or(0);
let max = std::cmp::max(prev, cur);
Ok(max)
})
}
}
pub fn init(context: &PyContext) {

View File

@@ -210,8 +210,8 @@ impl PyRange {
}
#[pymethod(name = "__len__")]
fn len(&self, _vm: &VirtualMachine) -> PyInt {
PyInt::new(self.length())
fn len(&self, _vm: &VirtualMachine) -> BigInt {
self.length()
}
#[pymethod(name = "__repr__")]

View File

@@ -263,3 +263,33 @@ pub fn is_valid_slice_arg(
Ok(None)
}
}
pub fn opt_len(obj: &PyObjectRef, vm: &VirtualMachine) -> Option<PyResult<usize>> {
vm.get_method(obj.clone(), "__len__").map(|len| {
let len = vm.invoke(&len?, vec![])?;
let len = len
.payload_if_subclass::<PyInt>(vm)
.ok_or_else(|| {
vm.new_type_error(format!(
"'{}' object cannot be interpreted as an integer",
len.class().name
))
})?
.as_bigint();
if len.is_negative() {
return Err(vm.new_value_error("__len__() should return >= 0".to_string()));
}
len.to_usize().ok_or_else(|| {
vm.new_overflow_error("cannot fit __len__() result into usize".to_string())
})
})
}
pub fn len(obj: &PyObjectRef, vm: &VirtualMachine) -> PyResult<usize> {
opt_len(obj, vm).unwrap_or_else(|| {
Err(vm.new_type_error(format!(
"object of type '{}' has no len()",
obj.class().name
)))
})
}

View File

@@ -1,12 +1,9 @@
use std::cell::{Cell, RefCell};
use std::cmp::Ordering;
use std::iter;
use std::ops::{AddAssign, SubAssign};
use std::rc::Rc;
use num_bigint::BigInt;
use num_traits::sign::Signed;
use num_traits::ToPrimitive;
use num_traits::{One, Signed, ToPrimitive, Zero};
use crate::function::{Args, OptionalArg, OptionalOption, PyFuncArgs};
use crate::obj::objbool;
@@ -166,11 +163,11 @@ impl PyItertoolsCount {
) -> PyResult<PyRef<Self>> {
let start = match start.into_option() {
Some(int) => int.as_bigint().clone(),
None => BigInt::from(0),
None => BigInt::zero(),
};
let step = match step.into_option() {
Some(int) => int.as_bigint().clone(),
None => BigInt::from(1),
None => BigInt::one(),
};
PyItertoolsCount {
@@ -183,7 +180,7 @@ impl PyItertoolsCount {
#[pymethod(name = "__next__")]
fn next(&self, _vm: &VirtualMachine) -> PyResult<PyInt> {
let result = self.cur.borrow().clone();
AddAssign::add_assign(&mut self.cur.borrow_mut() as &mut BigInt, &self.step);
*self.cur.borrow_mut() += &self.step;
Ok(PyInt::new(result))
}
@@ -296,16 +293,11 @@ impl PyItertoolsRepeat {
#[pymethod(name = "__next__")]
fn next(&self, vm: &VirtualMachine) -> PyResult {
if self.times.is_some() {
match self.times.as_ref().unwrap().borrow().cmp(&BigInt::from(0)) {
Ordering::Less | Ordering::Equal => return Err(new_stop_iteration(vm)),
_ => (),
};
SubAssign::sub_assign(
&mut self.times.as_ref().unwrap().borrow_mut() as &mut BigInt,
&BigInt::from(1),
);
if let Some(ref times) = self.times {
if *times.borrow() <= BigInt::zero() {
return Err(new_stop_iteration(vm));
}
*times.borrow_mut() -= 1;
}
Ok(self.object.clone())
@@ -315,6 +307,14 @@ impl PyItertoolsRepeat {
fn iter(zelf: PyRef<Self>, _vm: &VirtualMachine) -> PyRef<Self> {
zelf
}
#[pymethod(name = "__length_hint__")]
fn length_hint(&self, vm: &VirtualMachine) -> PyObjectRef {
match self.times {
Some(ref times) => vm.new_int(times.borrow().clone()),
None => vm.new_int(0),
}
}
}
#[pyclass(name = "starmap")]

View File

@@ -16,6 +16,7 @@ mod json;
mod keyword;
mod marshal;
mod math;
mod operator;
mod platform;
mod pystruct;
mod random;
@@ -74,6 +75,7 @@ pub fn get_module_inits() -> HashMap<String, StdlibInitFunc> {
"json".to_string() => Box::new(json::make_module),
"marshal".to_string() => Box::new(marshal::make_module),
"math".to_string() => Box::new(math::make_module),
"_operator".to_string() => Box::new(operator::make_module),
"platform".to_string() => Box::new(platform::make_module),
"regex_crate".to_string() => Box::new(re::make_module),
"random".to_string() => Box::new(random::make_module),

24
vm/src/stdlib/operator.rs Normal file
View File

@@ -0,0 +1,24 @@
use crate::function::OptionalArg;
use crate::obj::{objiter, objtype};
use crate::pyobject::{PyObjectRef, PyResult, TypeProtocol};
use crate::VirtualMachine;
fn operator_length_hint(obj: PyObjectRef, default: OptionalArg, vm: &VirtualMachine) -> PyResult {
let default = default.unwrap_or_else(|| vm.new_int(0));
if !objtype::isinstance(&default, &vm.ctx.types.int_type) {
return Err(vm.new_type_error(format!(
"'{}' type cannot be interpreted as an integer",
default.class().name
)));
}
let hint = objiter::length_hint(vm, obj)?
.map(|i| vm.new_int(i))
.unwrap_or(default);
Ok(hint)
}
pub fn make_module(vm: &VirtualMachine) -> PyObjectRef {
py_module!(vm, "_operator", {
"length_hint" => vm.ctx.new_rustfunc(operator_length_hint),
})
}