clean up vm::iterator::try_map

This commit is contained in:
Jeong YunWon
2021-10-04 03:26:06 +09:00
parent c039e107a8
commit 0576b2bf62
3 changed files with 30 additions and 43 deletions

View File

@@ -1,22 +0,0 @@
/*
* utilities to support iteration.
*/
use crate::{protocol::PyIter, PyObjectRef, PyResult, VirtualMachine};
pub fn try_map<F, R>(vm: &VirtualMachine, iter: &PyIter, cap: usize, mut f: F) -> PyResult<Vec<R>>
where
F: FnMut(PyObjectRef) -> PyResult<R>,
{
// TODO: fix extend to do this check (?), see test_extend in Lib/test/list_tests.py,
// https://github.com/python/cpython/blob/v3.9.0/Objects/listobject.c#L922-L928
if cap >= isize::max_value() as usize {
return Ok(Vec::new());
}
let mut results = Vec::with_capacity(cap);
for element in iter.iter_without_hint(vm)? {
results.push(f(element?)?);
}
results.shrink_to_fit();
Ok(results)
}

View File

@@ -59,7 +59,6 @@ pub mod frame;
mod frozen;
pub mod function;
pub mod import;
pub mod iterator;
pub mod protocol;
pub mod py_io;
pub mod py_serde;

View File

@@ -21,8 +21,8 @@ use crate::{
frame::{ExecutionResult, Frame, FrameRef},
frozen,
function::{FuncArgs, IntoFuncArgs},
import, iterator,
protocol::PyIterReturn,
import,
protocol::{PyIterIter, PyIterReturn},
scope::Scope,
signal::NSIG,
slots::PyComparisonOp,
@@ -1284,14 +1284,7 @@ impl VirtualMachine {
.map(|obj| func(obj.clone()))
.collect()
} else {
let iter = value.clone().get_iter(self)?;
let cap = match self.length_hint(value.clone()) {
Err(e) if e.class().is(&self.ctx.exceptions.runtime_error) => return Err(e),
Ok(Some(value)) => value,
// Use a power of 2 as a default capacity.
_ => 4,
};
iterator::try_map(self, &iter, cap, |obj| func(obj))
self.map_pyiter(value, |obj| func(obj))
}
}
@@ -1332,20 +1325,37 @@ impl VirtualMachine {
ref t @ PyTuple => Ok(t.as_slice().iter().cloned().map(f).collect()),
// TODO: put internal iterable type
obj => {
let iter = obj.clone().get_iter(self)?;
let cap = match self.length_hint(obj.clone()) {
Err(e) if e.class().is(&self.ctx.exceptions.runtime_error) => {
return Ok(Err(e))
}
Ok(Some(value)) => value,
// Use a power of 2 as a default capacity.
_ => 4,
};
Ok(iterator::try_map(self, &iter, cap, f))
Ok(self.map_pyiter(obj, f))
}
})
}
fn map_pyiter<F, R>(&self, value: &PyObjectRef, mut f: F) -> PyResult<Vec<R>>
where
F: FnMut(PyObjectRef) -> PyResult<R>,
{
let iter = value.clone().get_iter(self)?;
let cap = match self.length_hint(value.clone()) {
Err(e) if e.class().is(&self.ctx.exceptions.runtime_error) => return Err(e),
Ok(Some(value)) => Some(value),
// Use a power of 2 as a default capacity.
_ => None,
};
// TODO: fix extend to do this check (?), see test_extend in Lib/test/list_tests.py,
// https://github.com/python/cpython/blob/v3.9.0/Objects/listobject.c#L922-L928
if let Some(cap) = cap {
if cap >= isize::max_value() as usize {
return Ok(Vec::new());
}
}
let mut results = PyIterIter::new(self, iter.as_object(), cap)
.map(|element| f(element?))
.collect::<PyResult<Vec<_>>>()?;
results.shrink_to_fit();
Ok(results)
}
// get_attribute should be used for full attribute access (usually from user code).
#[cfg_attr(feature = "flame-it", flame("VirtualMachine"))]
pub fn get_attribute<T>(&self, obj: PyObjectRef, attr_name: T) -> PyResult