diff --git a/vm/src/builtins/map.rs b/vm/src/builtins/map.rs index c3f2a44d6..f2102f6e9 100644 --- a/vm/src/builtins/map.rs +++ b/vm/src/builtins/map.rs @@ -1,7 +1,6 @@ use super::PyTypeRef; use crate::{ function::PosArgs, - iterator, protocol::{PyIter, PyIterReturn}, slots::{IteratorIterable, SlotConstructor, SlotIterator}, PyClassImpl, PyContext, PyObjectRef, PyRef, PyResult, PyValue, VirtualMachine, @@ -38,7 +37,7 @@ impl PyMap { #[pymethod(magic)] fn length_hint(&self, vm: &VirtualMachine) -> PyResult { self.iterators.iter().try_fold(0, |prev, cur| { - let cur = iterator::length_hint(vm, cur.as_object().clone())?.unwrap_or(0); + let cur = vm.length_hint(cur.as_object().clone())?.unwrap_or(0); let max = std::cmp::max(prev, cur); Ok(max) }) diff --git a/vm/src/frame.rs b/vm/src/frame.rs index 308d9ee12..7f9b7a53e 100644 --- a/vm/src/frame.rs +++ b/vm/src/frame.rs @@ -14,7 +14,6 @@ use crate::{ coroutine::Coro, exceptions::{self, ExceptionCtor}, function::FuncArgs, - iterator, protocol::{PyIter, PyIterReturn}, scope::Scope, slots::PyComparisonOp, diff --git a/vm/src/function/argument.rs b/vm/src/function/argument.rs index d60a65c13..54bf4e485 100644 --- a/vm/src/function/argument.rs +++ b/vm/src/function/argument.rs @@ -1,7 +1,6 @@ use super::IntoFuncArgs; use crate::{ builtins::iter::PySequenceIterator, - iterator, protocol::{PyIter, PyIterReturn}, PyObjectRef, PyResult, PyValue, TryFromObject, TypeProtocol, VirtualMachine, }; @@ -58,7 +57,7 @@ impl ArgIterable { None => PySequenceIterator::new(self.iterable.clone()).into_object(vm), }; - let length_hint = iterator::length_hint(vm, iter_obj.clone())?; + let length_hint = vm.length_hint(iter_obj.clone())?; Ok(PyIterator { vm, diff --git a/vm/src/iterator.rs b/vm/src/iterator.rs index ddedbb9fb..a3367f7b6 100644 --- a/vm/src/iterator.rs +++ b/vm/src/iterator.rs @@ -3,9 +3,8 @@ */ use crate::{ - builtins::{PyBaseExceptionRef, PyInt}, protocol::{PyIter, PyIterReturn}, - IdProtocol, PyObjectRef, PyResult, TypeProtocol, VirtualMachine, + PyObjectRef, PyResult, VirtualMachine, }; pub fn try_map(vm: &VirtualMachine, iter: &PyIter, cap: usize, mut f: F) -> PyResult> @@ -25,52 +24,6 @@ where Ok(results) } -pub fn length_hint(vm: &VirtualMachine, iter: PyObjectRef) -> PyResult> { - if let Some(len) = vm.obj_len_opt(&iter) { - match len { - Ok(len) => return Ok(Some(len)), - Err(e) => { - if !e.isinstance(&vm.ctx.exceptions.type_error) { - return Err(e); - } - } - } - } - let hint = match vm.get_method(iter, "__length_hint__") { - Some(hint) => hint?, - None => return Ok(None), - }; - let result = match vm.invoke(&hint, ()) { - Ok(res) => { - if res.is(&vm.ctx.not_implemented) { - return Ok(None); - } - res - } - Err(e) => { - return if e.isinstance(&vm.ctx.exceptions.type_error) { - Ok(None) - } else { - Err(e) - } - } - }; - let hint = result - .payload_if_subclass::(vm) - .ok_or_else(|| { - vm.new_type_error(format!( - "'{}' object cannot be interpreted as an integer", - result.class().name() - )) - })? - .try_to_primitive::(vm)?; - if hint.is_negative() { - Err(vm.new_value_error("__length_hint__() should return >= 0".to_owned())) - } else { - Ok(Some(hint as usize)) - } -} - // pub fn seq_iter_method(obj: PyObjectRef) -> PySequenceIterator { // PySequenceIterator::new_forward(obj) // } diff --git a/vm/src/stdlib/operator.rs b/vm/src/stdlib/operator.rs index 748c0a7de..dd506a786 100644 --- a/vm/src/stdlib/operator.rs +++ b/vm/src/stdlib/operator.rs @@ -13,7 +13,6 @@ mod _operator { use crate::{ builtins::{PyInt, PyIntRef, PyStrRef, PyTypeRef}, function::{ArgBytesLike, FuncArgs, KwArgs, OptionalArg}, - iterator, protocol::{PyIter, PyIterReturn}, slots::{ Callable, @@ -284,7 +283,8 @@ mod _operator { v.payload::().unwrap().try_to_primitive(vm) }) .unwrap_or(Ok(0))?; - iterator::length_hint(vm, obj).map(|v| vm.ctx.new_int(v.unwrap_or(default))) + vm.length_hint(obj) + .map(|v| vm.ctx.new_int(v.unwrap_or(default))) } // Inplace Operators diff --git a/vm/src/vm.rs b/vm/src/vm.rs index fca9bd673..7656ec518 100644 --- a/vm/src/vm.rs +++ b/vm/src/vm.rs @@ -1285,7 +1285,7 @@ impl VirtualMachine { .collect() } else { let iter = value.clone().get_iter(self)?; - let cap = match iterator::length_hint(self, value.clone()) { + let cap = match self.length_hint(value.clone()) { Err(e) if e.class().is(&self.ctx.exceptions.runtime_error) => return Err(e), Ok(Some(value)) => value, // Use a power of 2 as a default capacity. @@ -1333,7 +1333,7 @@ impl VirtualMachine { // TODO: put internal iterable type obj => { let iter = obj.clone().get_iter(self)?; - let cap = match iterator::length_hint(self, obj.clone()) { + let cap = match self.length_hint(obj.clone()) { Err(e) if e.class().is(&self.ctx.exceptions.runtime_error) => { return Ok(Err(e)) } @@ -1926,6 +1926,52 @@ impl VirtualMachine { }) } + pub fn length_hint(&self, iter: PyObjectRef) -> PyResult> { + if let Some(len) = self.obj_len_opt(&iter) { + match len { + Ok(len) => return Ok(Some(len)), + Err(e) => { + if !e.isinstance(&self.ctx.exceptions.type_error) { + return Err(e); + } + } + } + } + let hint = match self.get_method(iter, "__length_hint__") { + Some(hint) => hint?, + None => return Ok(None), + }; + let result = match self.invoke(&hint, ()) { + Ok(res) => { + if res.is(&self.ctx.not_implemented) { + return Ok(None); + } + res + } + Err(e) => { + return if e.isinstance(&self.ctx.exceptions.type_error) { + Ok(None) + } else { + Err(e) + } + } + }; + let hint = result + .payload_if_subclass::(self) + .ok_or_else(|| { + self.new_type_error(format!( + "'{}' object cannot be interpreted as an integer", + result.class().name() + )) + })? + .try_to_primitive::(self)?; + if hint.is_negative() { + Err(self.new_value_error("__length_hint__() should return >= 0".to_owned())) + } else { + Ok(Some(hint as usize)) + } + } + /// Checks that the multiplication is able to be performed. On Ok returns the /// index as a usize for sequences to be able to use immediately. pub fn check_repeat_or_memory_error(&self, length: usize, n: isize) -> PyResult {