From 6b537f15e5cb358812a27a9c82a878730b7f5c26 Mon Sep 17 00:00:00 2001 From: coolreader18 <33094578+coolreader18@users.noreply.github.com> Date: Sun, 29 Dec 2019 11:39:25 -0600 Subject: [PATCH 1/5] Add objsequence::{opt_,}len --- vm/src/builtins.rs | 8 +++----- vm/src/exceptions.rs | 2 -- vm/src/obj/objsequence.rs | 30 ++++++++++++++++++++++++++++++ 3 files changed, 33 insertions(+), 7 deletions(-) diff --git a/vm/src/builtins.rs b/vm/src/builtins.rs index 73f237a71..d8ef813ed 100644 --- a/vm/src/builtins.rs +++ b/vm/src/builtins.rs @@ -22,6 +22,7 @@ use crate::obj::objdict::PyDictRef; use crate::obj::objfunction::PyFunctionRef; use crate::obj::objint::{self, PyIntRef}; use crate::obj::objiter; +use crate::obj::objsequence; use crate::obj::objstr::{PyString, PyStringRef}; use crate::obj::objtype::{self, PyClassRef}; use crate::pyhash; @@ -384,11 +385,8 @@ fn builtin_iter(iter_target: PyObjectRef, vm: &VirtualMachine) -> PyResult { objiter::get_iter(vm, &iter_target) } -fn builtin_len(obj: PyObjectRef, vm: &VirtualMachine) -> PyResult { - let method = vm.get_method_or_type_error(obj.clone(), "__len__", || { - format!("object of type '{}' has no len()", obj.class().name) - })?; - vm.invoke(&method, PyFuncArgs::default()) +fn builtin_len(obj: PyObjectRef, vm: &VirtualMachine) -> PyResult { + objsequence::len(&obj, vm) } fn builtin_locals(vm: &VirtualMachine) -> PyDictRef { diff --git a/vm/src/exceptions.rs b/vm/src/exceptions.rs index 5a7cc1d8d..506a7f55d 100644 --- a/vm/src/exceptions.rs +++ b/vm/src/exceptions.rs @@ -250,8 +250,6 @@ pub fn write_exception_inner( write_traceback_entry(output, traceback)?; tb = traceback.next.as_ref(); } - } else { - writeln!(output, "No traceback set on exception")?; } let varargs = exc.args(); diff --git a/vm/src/obj/objsequence.rs b/vm/src/obj/objsequence.rs index 4b1590f20..aecfea219 100644 --- a/vm/src/obj/objsequence.rs +++ b/vm/src/obj/objsequence.rs @@ -263,3 +263,33 @@ pub fn is_valid_slice_arg( Ok(None) } } + +pub fn opt_len(obj: &PyObjectRef, vm: &VirtualMachine) -> Option> { + vm.get_method(obj.clone(), "__len__").map(|len| { + let len = vm.invoke(&len?, vec![])?; + let len = len + .payload_if_subclass::(vm) + .ok_or_else(|| { + vm.new_type_error(format!( + "'{}' object cannot be interpreted as an integer", + len.class().name + )) + })? + .as_bigint(); + if len.is_negative() { + return Err(vm.new_value_error("__len__() should return >= 0".to_string())); + } + len.to_usize().ok_or_else(|| { + vm.new_overflow_error("cannot fit __len__() result into usize".to_string()) + }) + }) +} + +pub fn len(obj: &PyObjectRef, vm: &VirtualMachine) -> PyResult { + opt_len(obj, vm).unwrap_or_else(|| { + Err(vm.new_type_error(format!( + "object of type '{}' has no len()", + obj.class().name + ))) + }) +} From fc3ac169d029142b6d133d69999a6cab62c753fe Mon Sep 17 00:00:00 2001 From: coolreader18 <33094578+coolreader18@users.noreply.github.com> Date: Sun, 29 Dec 2019 11:39:50 -0600 Subject: [PATCH 2/5] Add length_hint functionality --- vm/src/obj/objiter.rs | 50 ++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 49 insertions(+), 1 deletion(-) diff --git a/vm/src/obj/objiter.rs b/vm/src/obj/objiter.rs index 4fdd616e7..2a5282e32 100644 --- a/vm/src/obj/objiter.rs +++ b/vm/src/obj/objiter.rs @@ -2,8 +2,11 @@ * Various types to support iteration. */ +use num_traits::{Signed, ToPrimitive}; use std::cell::Cell; +use super::objint::PyInt; +use super::objsequence; use super::objtype::{self, PyClassRef}; use crate::exceptions::PyBaseExceptionRef; use crate::pyobject::{ @@ -61,10 +64,12 @@ pub fn get_next_object( /* Retrieve all elements from an iterator */ pub fn get_all(vm: &VirtualMachine, iter_obj: &PyObjectRef) -> PyResult> { - let mut elements = vec![]; + let cap = length_hint(vm, iter_obj.clone())?.unwrap_or(0); + let mut elements = Vec::with_capacity(cap); while let Some(element) = get_next_object(vm, iter_obj)? { elements.push(T::try_from_object(vm, element)?); } + elements.shrink_to_fit(); Ok(elements) } @@ -83,6 +88,49 @@ pub fn stop_iter_value(vm: &VirtualMachine, exc: &PyBaseExceptionRef) -> PyResul Ok(val) } +pub fn length_hint(vm: &VirtualMachine, iter: PyObjectRef) -> PyResult> { + if let Some(len) = objsequence::opt_len(&iter, vm) { + match len { + Ok(len) => return Ok(Some(len)), + Err(e) => { + if !objtype::isinstance(&e, &vm.ctx.exceptions.type_error) { + return Err(e); + } + } + } + } + let hint = match vm.get_method(iter, "__length_hint__") { + Some(hint) => hint?, + None => return Ok(None), + }; + let result = match vm.invoke(&hint, vec![]) { + Ok(res) => res, + Err(e) => { + if objtype::isinstance(&e, &vm.ctx.exceptions.type_error) { + return Ok(None); + } else { + return Err(e); + } + } + }; + let result = result + .payload_if_subclass::(vm) + .ok_or_else(|| { + vm.new_type_error(format!( + "'{}' object cannot be interpreted as an integer", + result.class().name + )) + })? + .as_bigint(); + if result.is_negative() { + return Err(vm.new_value_error("__length_hint__() should return >= 0".to_string())); + } + let hint = result.to_usize().ok_or_else(|| { + vm.new_value_error("Python int too large to convert to Rust usize".to_string()) + })?; + Ok(Some(hint)) +} + #[pyclass] #[derive(Debug)] pub struct PySequenceIterator { From c65dc336f5fcb681200cf3deea0af9f91c460f59 Mon Sep 17 00:00:00 2001 From: coolreader18 <33094578+coolreader18@users.noreply.github.com> Date: Sun, 29 Dec 2019 11:40:05 -0600 Subject: [PATCH 3/5] Add _operator module --- vm/src/stdlib/mod.rs | 2 ++ vm/src/stdlib/operator.rs | 24 ++++++++++++++++++++++++ 2 files changed, 26 insertions(+) create mode 100644 vm/src/stdlib/operator.rs diff --git a/vm/src/stdlib/mod.rs b/vm/src/stdlib/mod.rs index 5e21b04e7..d3dd0e732 100644 --- a/vm/src/stdlib/mod.rs +++ b/vm/src/stdlib/mod.rs @@ -16,6 +16,7 @@ mod json; mod keyword; mod marshal; mod math; +mod operator; mod platform; mod pystruct; mod random; @@ -74,6 +75,7 @@ pub fn get_module_inits() -> HashMap { "json".to_string() => Box::new(json::make_module), "marshal".to_string() => Box::new(marshal::make_module), "math".to_string() => Box::new(math::make_module), + "_operator".to_string() => Box::new(operator::make_module), "platform".to_string() => Box::new(platform::make_module), "regex_crate".to_string() => Box::new(re::make_module), "random".to_string() => Box::new(random::make_module), diff --git a/vm/src/stdlib/operator.rs b/vm/src/stdlib/operator.rs new file mode 100644 index 000000000..b2fbc7bea --- /dev/null +++ b/vm/src/stdlib/operator.rs @@ -0,0 +1,24 @@ +use crate::function::OptionalArg; +use crate::obj::{objiter, objtype}; +use crate::pyobject::{PyObjectRef, PyResult, TypeProtocol}; +use crate::VirtualMachine; + +fn operator_length_hint(obj: PyObjectRef, default: OptionalArg, vm: &VirtualMachine) -> PyResult { + let default = default.unwrap_or_else(|| vm.new_int(0)); + if !objtype::isinstance(&default, &vm.ctx.types.int_type) { + return Err(vm.new_type_error(format!( + "'{}' type cannot be interpreted as an integer", + default.class().name + ))); + } + let hint = objiter::length_hint(vm, obj)? + .map(|i| vm.new_int(i)) + .unwrap_or(default); + Ok(hint) +} + +pub fn make_module(vm: &VirtualMachine) -> PyObjectRef { + py_module!(vm, "_operator", { + "length_hint" => vm.ctx.new_rustfunc(operator_length_hint), + }) +} From 098393d4d6e4db91d31868800ef7431535825971 Mon Sep 17 00:00:00 2001 From: coolreader18 <33094578+coolreader18@users.noreply.github.com> Date: Thu, 2 Jan 2020 23:41:35 -0600 Subject: [PATCH 4/5] Add __length_hint__ to a bunch of iterators --- vm/src/obj/objiter.rs | 14 ++++++++++++++ vm/src/obj/objlist.rs | 10 ++++++++++ vm/src/obj/objmap.rs | 9 +++++++++ vm/src/obj/objrange.rs | 4 ++-- vm/src/stdlib/itertools.rs | 34 +++++++++++++++++----------------- 5 files changed, 52 insertions(+), 19 deletions(-) diff --git a/vm/src/obj/objiter.rs b/vm/src/obj/objiter.rs index 2a5282e32..408cb55d6 100644 --- a/vm/src/obj/objiter.rs +++ b/vm/src/obj/objiter.rs @@ -172,6 +172,20 @@ impl PySequenceIterator { fn iter(zelf: PyRef, _vm: &VirtualMachine) -> PyRef { zelf } + + #[pymethod(name = "__length_hint__")] + fn length_hint(&self, vm: &VirtualMachine) -> PyResult { + let pos = self.position.get(); + let hint = if self.reversed { + pos + 1 + } else { + let len = objsequence::opt_len(&self.obj, vm).unwrap_or_else(|| { + Err(vm.new_type_error("sequence has no __len__ method".to_string())) + })?; + len as isize - pos + }; + Ok(hint) + } } pub fn seq_iter_method(obj: PyObjectRef, _vm: &VirtualMachine) -> PySequenceIterator { diff --git a/vm/src/obj/objlist.rs b/vm/src/obj/objlist.rs index d4bd8e02c..85672ac24 100644 --- a/vm/src/obj/objlist.rs +++ b/vm/src/obj/objlist.rs @@ -873,6 +873,11 @@ impl PyListIterator { fn iter(zelf: PyRef, _vm: &VirtualMachine) -> PyRef { zelf } + + #[pymethod(name = "__length_hint__")] + fn length_hint(&self, _vm: &VirtualMachine) -> usize { + self.list.elements.borrow().len() - self.position.get() + } } #[pyclass] @@ -906,6 +911,11 @@ impl PyListReverseIterator { fn iter(zelf: PyRef, _vm: &VirtualMachine) -> PyRef { zelf } + + #[pymethod(name = "__length_hint__")] + fn length_hint(&self, _vm: &VirtualMachine) -> usize { + self.position.get() + } } pub fn init(context: &PyContext) { diff --git a/vm/src/obj/objmap.rs b/vm/src/obj/objmap.rs index fde6b95cc..1fd272afb 100644 --- a/vm/src/obj/objmap.rs +++ b/vm/src/obj/objmap.rs @@ -58,6 +58,15 @@ impl PyMap { fn iter(zelf: PyRef, _vm: &VirtualMachine) -> PyRef { zelf } + + #[pymethod(name = "__length_hint__")] + fn length_hint(&self, vm: &VirtualMachine) -> PyResult { + self.iterators.iter().try_fold(0, |prev, cur| { + let cur = objiter::length_hint(vm, cur.clone())?.unwrap_or(0); + let max = std::cmp::max(prev, cur); + Ok(max) + }) + } } pub fn init(context: &PyContext) { diff --git a/vm/src/obj/objrange.rs b/vm/src/obj/objrange.rs index b35b1701f..c3885beb9 100644 --- a/vm/src/obj/objrange.rs +++ b/vm/src/obj/objrange.rs @@ -210,8 +210,8 @@ impl PyRange { } #[pymethod(name = "__len__")] - fn len(&self, _vm: &VirtualMachine) -> PyInt { - PyInt::new(self.length()) + fn len(&self, _vm: &VirtualMachine) -> BigInt { + self.length() } #[pymethod(name = "__repr__")] diff --git a/vm/src/stdlib/itertools.rs b/vm/src/stdlib/itertools.rs index 1247b38a8..63648d303 100644 --- a/vm/src/stdlib/itertools.rs +++ b/vm/src/stdlib/itertools.rs @@ -1,12 +1,9 @@ use std::cell::{Cell, RefCell}; -use std::cmp::Ordering; use std::iter; -use std::ops::{AddAssign, SubAssign}; use std::rc::Rc; use num_bigint::BigInt; -use num_traits::sign::Signed; -use num_traits::ToPrimitive; +use num_traits::{One, Signed, ToPrimitive, Zero}; use crate::function::{Args, OptionalArg, OptionalOption, PyFuncArgs}; use crate::obj::objbool; @@ -166,11 +163,11 @@ impl PyItertoolsCount { ) -> PyResult> { let start = match start.into_option() { Some(int) => int.as_bigint().clone(), - None => BigInt::from(0), + None => BigInt::zero(), }; let step = match step.into_option() { Some(int) => int.as_bigint().clone(), - None => BigInt::from(1), + None => BigInt::one(), }; PyItertoolsCount { @@ -183,7 +180,7 @@ impl PyItertoolsCount { #[pymethod(name = "__next__")] fn next(&self, _vm: &VirtualMachine) -> PyResult { let result = self.cur.borrow().clone(); - AddAssign::add_assign(&mut self.cur.borrow_mut() as &mut BigInt, &self.step); + *self.cur.borrow_mut() += &self.step; Ok(PyInt::new(result)) } @@ -296,16 +293,11 @@ impl PyItertoolsRepeat { #[pymethod(name = "__next__")] fn next(&self, vm: &VirtualMachine) -> PyResult { - if self.times.is_some() { - match self.times.as_ref().unwrap().borrow().cmp(&BigInt::from(0)) { - Ordering::Less | Ordering::Equal => return Err(new_stop_iteration(vm)), - _ => (), - }; - - SubAssign::sub_assign( - &mut self.times.as_ref().unwrap().borrow_mut() as &mut BigInt, - &BigInt::from(1), - ); + if let Some(ref times) = self.times { + if *times.borrow() <= BigInt::zero() { + return Err(new_stop_iteration(vm)); + } + *times.borrow_mut() -= 1; } Ok(self.object.clone()) @@ -315,6 +307,14 @@ impl PyItertoolsRepeat { fn iter(zelf: PyRef, _vm: &VirtualMachine) -> PyRef { zelf } + + #[pymethod(name = "__length_hint__")] + fn length_hint(&self, vm: &VirtualMachine) -> PyObjectRef { + match self.times { + Some(ref times) => vm.new_int(times.borrow().clone()), + None => vm.new_int(0), + } + } } #[pyclass(name = "starmap")] From a26e5f23f86f75c9087c3b2a4ab546a56ff46ff7 Mon Sep 17 00:00:00 2001 From: coolreader18 <33094578+coolreader18@users.noreply.github.com> Date: Fri, 3 Jan 2020 00:10:30 -0600 Subject: [PATCH 5/5] Add dict iterator __length_hint__ --- tests/snippets/dict.py | 13 +++++++++++++ vm/src/dictdatatype.rs | 26 +++++++++++++------------- vm/src/obj/objdict.rs | 8 ++++++++ 3 files changed, 34 insertions(+), 13 deletions(-) diff --git a/tests/snippets/dict.py b/tests/snippets/dict.py index 64645bbc4..2359abe37 100644 --- a/tests/snippets/dict.py +++ b/tests/snippets/dict.py @@ -254,3 +254,16 @@ assert c == {'bla', 'c', 'd', 'f'} assert not {}.__ne__({}) assert {}.__ne__({'a':'b'}) assert {}.__ne__(1) == NotImplemented + +it = iter({0: 1, 2: 3, 4:5, 6:7}) +assert it.__length_hint__() == 4 +next(it) +assert it.__length_hint__() == 3 +next(it) +assert it.__length_hint__() == 2 +next(it) +assert it.__length_hint__() == 1 +next(it) +assert it.__length_hint__() == 0 +assert_raises(StopIteration, next, it) +assert it.__length_hint__() == 0 diff --git a/vm/src/dictdatatype.rs b/vm/src/dictdatatype.rs index 1e809f22c..2d0b93582 100644 --- a/vm/src/dictdatatype.rs +++ b/vm/src/dictdatatype.rs @@ -204,26 +204,26 @@ impl Dict { } pub fn next_entry(&self, position: &mut EntryIndex) -> Option<(&PyObjectRef, &T)> { - while *position < self.entries.len() { - if let Some(DictEntry { key, value, .. }) = &self.entries[*position] { - *position += 1; - return Some((key, value)); - } + self.entries[*position..].iter().find_map(|entry| { *position += 1; - } - None + entry + .as_ref() + .map(|DictEntry { key, value, .. }| (key, value)) + }) + } + + pub fn len_from_entry_index(&self, position: EntryIndex) -> usize { + self.entries[position..].iter().flatten().count() } pub fn has_changed_size(&self, position: &DictSize) -> bool { position.size != self.size || self.entries.len() != position.entries_size } - pub fn keys<'a>(&'a self) -> Box + 'a> { - Box::new( - self.entries - .iter() - .filter_map(|v| v.as_ref().map(|v| v.key.clone())), - ) + pub fn keys<'a>(&'a self) -> impl Iterator + 'a { + self.entries + .iter() + .filter_map(|v| v.as_ref().map(|v| v.key.clone())) } /// Lookup the index for the given key. diff --git a/vm/src/obj/objdict.rs b/vm/src/obj/objdict.rs index eb61d7b2b..9b85bb272 100644 --- a/vm/src/obj/objdict.rs +++ b/vm/src/obj/objdict.rs @@ -545,6 +545,14 @@ macro_rules! dict_iterator { fn iter(zelf: PyRef, _vm: &VirtualMachine) -> PyRef { zelf } + + #[pymethod(name = "__length_hint__")] + fn length_hint(&self, _vm: &VirtualMachine) -> usize { + self.dict + .entries + .borrow() + .len_from_entry_index(self.position.get()) + } } impl PyValue for $iter_name {