[VM] Object pickling implementation for product object (python itertools) (#5089)

* Implemented __reduce__, __setstate__ in product object
This commit is contained in:
Amuthan Mannar
2023-10-12 16:42:33 +05:30
committed by GitHub
parent 830389f62c
commit 9241e2e5d5
2 changed files with 91 additions and 9 deletions

View File

@@ -208,7 +208,7 @@ class TestBasicOps(unittest.TestCase):
it = chain()
it.__setstate__((iter(['abc', 'def']), iter(['ghi'])))
self.assertEqual(list(it), ['ghi', 'a', 'b', 'c', 'd', 'e', 'f'])
# TODO: RUSTPYTHON
@unittest.expectedFailure
def test_combinations(self):
@@ -1165,8 +1165,7 @@ class TestBasicOps(unittest.TestCase):
self.assertEqual(len(set(map(id, product('abc', 'def')))), 1)
self.assertNotEqual(len(set(map(id, list(product('abc', 'def'))))), 1)
# TODO: RUSTPYTHON
@unittest.expectedFailure
def test_product_pickling(self):
# check copy, deepcopy, pickle
for args, result in [
@@ -2297,7 +2296,7 @@ class RegressionTests(unittest.TestCase):
class SubclassWithKwargsTest(unittest.TestCase):
# TODO: RUSTPYTHON
@unittest.expectedFailure
def test_keywords_in_subclass(self):

View File

@@ -2,6 +2,7 @@ pub(crate) use decl::make_module;
#[pymodule(name = "itertools")]
mod decl {
use crate::stdlib::itertools::decl::int::get_value;
use crate::{
builtins::{
int, tuple::IntoPyTuple, PyGenericAlias, PyInt, PyIntRef, PyList, PyTuple, PyTupleRef,
@@ -110,7 +111,9 @@ mod decl {
Ok(())
}
}
impl SelfIter for PyItertoolsChain {}
impl IterNext for PyItertoolsChain {
fn next(zelf: &Py<Self>, vm: &VirtualMachine) -> PyResult<PyIterReturn> {
let Some(source) = zelf.source.read().clone() else {
@@ -201,6 +204,7 @@ mod decl {
}
impl SelfIter for PyItertoolsCompress {}
impl IterNext for PyItertoolsCompress {
fn next(zelf: &Py<Self>, vm: &VirtualMachine) -> PyResult<PyIterReturn> {
loop {
@@ -268,7 +272,9 @@ mod decl {
(zelf.class().to_owned(), (zelf.cur.read().clone(),))
}
}
impl SelfIter for PyItertoolsCount {}
impl IterNext for PyItertoolsCount {
fn next(zelf: &Py<Self>, vm: &VirtualMachine) -> PyResult<PyIterReturn> {
let mut cur = zelf.cur.write();
@@ -316,7 +322,9 @@ mod decl {
#[pyclass(with(IterNext, Iterable, Constructor), flags(BASETYPE))]
impl PyItertoolsCycle {}
impl SelfIter for PyItertoolsCycle {}
impl IterNext for PyItertoolsCycle {
fn next(zelf: &Py<Self>, vm: &VirtualMachine) -> PyResult<PyIterReturn> {
let item = if let PyIterReturn::Return(item) = zelf.iter.next(vm)? {
@@ -401,6 +409,7 @@ mod decl {
}
impl SelfIter for PyItertoolsRepeat {}
impl IterNext for PyItertoolsRepeat {
fn next(zelf: &Py<Self>, _vm: &VirtualMachine) -> PyResult<PyIterReturn> {
if let Some(ref times) = zelf.times {
@@ -466,7 +475,9 @@ mod decl {
)
}
}
impl SelfIter for PyItertoolsStarmap {}
impl IterNext for PyItertoolsStarmap {
fn next(zelf: &Py<Self>, vm: &VirtualMachine) -> PyResult<PyIterReturn> {
let obj = zelf.iterable.next(vm)?;
@@ -537,7 +548,9 @@ mod decl {
Ok(())
}
}
impl SelfIter for PyItertoolsTakewhile {}
impl IterNext for PyItertoolsTakewhile {
fn next(zelf: &Py<Self>, vm: &VirtualMachine) -> PyResult<PyIterReturn> {
if zelf.stop_flag.load() {
@@ -618,7 +631,9 @@ mod decl {
Ok(())
}
}
impl SelfIter for PyItertoolsDropwhile {}
impl IterNext for PyItertoolsDropwhile {
fn next(zelf: &Py<Self>, vm: &VirtualMachine) -> PyResult<PyIterReturn> {
let predicate = &zelf.predicate;
@@ -629,7 +644,7 @@ mod decl {
let obj = match iterable.next(vm)? {
PyIterReturn::Return(obj) => obj,
PyIterReturn::StopIteration(v) => {
return Ok(PyIterReturn::StopIteration(v))
return Ok(PyIterReturn::StopIteration(v));
}
};
let pred = predicate.clone();
@@ -737,7 +752,9 @@ mod decl {
Ok(PyIterReturn::Return((new_value, new_key)))
}
}
impl SelfIter for PyItertoolsGroupBy {}
impl IterNext for PyItertoolsGroupBy {
fn next(zelf: &Py<Self>, vm: &VirtualMachine) -> PyResult<PyIterReturn> {
let mut state = zelf.state.lock();
@@ -753,7 +770,7 @@ mod decl {
let (value, new_key) = match zelf.advance(vm)? {
PyIterReturn::Return(obj) => obj,
PyIterReturn::StopIteration(v) => {
return Ok(PyIterReturn::StopIteration(v))
return Ok(PyIterReturn::StopIteration(v));
}
};
if !vm.bool_eq(&new_key, &old_key)? {
@@ -764,7 +781,7 @@ mod decl {
match zelf.advance(vm)? {
PyIterReturn::Return(obj) => obj,
PyIterReturn::StopIteration(v) => {
return Ok(PyIterReturn::StopIteration(v))
return Ok(PyIterReturn::StopIteration(v));
}
}
};
@@ -797,7 +814,9 @@ mod decl {
#[pyclass(with(IterNext, Iterable))]
impl PyItertoolsGrouper {}
impl SelfIter for PyItertoolsGrouper {}
impl IterNext for PyItertoolsGrouper {
fn next(zelf: &Py<Self>, vm: &VirtualMachine) -> PyResult<PyIterReturn> {
let old_key = {
@@ -960,6 +979,7 @@ mod decl {
}
impl SelfIter for PyItertoolsIslice {}
impl IterNext for PyItertoolsIslice {
fn next(zelf: &Py<Self>, vm: &VirtualMachine) -> PyResult<PyIterReturn> {
while zelf.cur.load() < zelf.next.load() {
@@ -1033,7 +1053,9 @@ mod decl {
)
}
}
impl SelfIter for PyItertoolsFilterFalse {}
impl IterNext for PyItertoolsFilterFalse {
fn next(zelf: &Py<Self>, vm: &VirtualMachine) -> PyResult<PyIterReturn> {
let predicate = &zelf.predicate;
@@ -1142,6 +1164,7 @@ mod decl {
}
impl SelfIter for PyItertoolsAccumulate {}
impl IterNext for PyItertoolsAccumulate {
fn next(zelf: &Py<Self>, vm: &VirtualMachine) -> PyResult<PyIterReturn> {
let iterable = &zelf.iterable;
@@ -1153,7 +1176,7 @@ mod decl {
None => match iterable.next(vm)? {
PyIterReturn::Return(obj) => obj,
PyIterReturn::StopIteration(v) => {
return Ok(PyIterReturn::StopIteration(v))
return Ok(PyIterReturn::StopIteration(v));
}
},
Some(obj) => obj.clone(),
@@ -1162,7 +1185,7 @@ mod decl {
let obj = match iterable.next(vm)? {
PyIterReturn::Return(obj) => obj,
PyIterReturn::StopIteration(v) => {
return Ok(PyIterReturn::StopIteration(v))
return Ok(PyIterReturn::StopIteration(v));
}
};
match &zelf.binop {
@@ -1348,7 +1371,60 @@ mod decl {
self.cur.store(idxs.len() - 1);
}
}
#[pymethod(magic)]
fn setstate(zelf: PyRef<Self>, state: PyTupleRef, vm: &VirtualMachine) -> PyResult<()> {
let args = state.as_slice();
if args.len() != zelf.pools.len() {
let msg = "Invalid number of arguments".to_string();
return Err(vm.new_type_error(msg));
}
let mut idxs: PyRwLockWriteGuard<'_, Vec<usize>> = zelf.idxs.write();
idxs.clear();
for s in 0..args.len() {
let index = get_value(state.get(s).unwrap()).to_usize().unwrap();
let pool_size = zelf.pools.get(s).unwrap().len();
if pool_size == 0 {
zelf.stop.store(true);
return Ok(());
}
if index >= pool_size {
idxs.push(pool_size - 1);
} else {
idxs.push(index);
}
}
zelf.stop.store(false);
Ok(())
}
#[pymethod(magic)]
fn reduce(zelf: PyRef<Self>, vm: &VirtualMachine) -> PyTupleRef {
let class = zelf.class().to_owned();
if zelf.stop.load() {
return vm.new_tuple((class, (vm.ctx.empty_tuple.clone(),)));
}
let mut pools: Vec<PyObjectRef> = Vec::new();
for element in zelf.pools.iter() {
pools.push(element.clone().into_pytuple(vm).into());
}
let mut indices: Vec<PyObjectRef> = Vec::new();
for item in &zelf.idxs.read()[..] {
indices.push(vm.new_pyobj(*item));
}
vm.new_tuple((
class,
pools.clone().into_pytuple(vm),
indices.into_pytuple(vm),
))
}
}
impl SelfIter for PyItertoolsProduct {}
impl IterNext for PyItertoolsProduct {
fn next(zelf: &Py<Self>, vm: &VirtualMachine) -> PyResult<PyIterReturn> {
@@ -1563,6 +1639,7 @@ mod decl {
impl PyItertoolsCombinationsWithReplacement {}
impl SelfIter for PyItertoolsCombinationsWithReplacement {}
impl IterNext for PyItertoolsCombinationsWithReplacement {
fn next(zelf: &Py<Self>, vm: &VirtualMachine) -> PyResult<PyIterReturn> {
// stop signal
@@ -1679,7 +1756,9 @@ mod decl {
))
}
}
impl SelfIter for PyItertoolsPermutations {}
impl IterNext for PyItertoolsPermutations {
fn next(zelf: &Py<Self>, vm: &VirtualMachine) -> PyResult<PyIterReturn> {
// stop signal
@@ -1802,7 +1881,9 @@ mod decl {
Ok(())
}
}
impl SelfIter for PyItertoolsZipLongest {}
impl IterNext for PyItertoolsZipLongest {
fn next(zelf: &Py<Self>, vm: &VirtualMachine) -> PyResult<PyIterReturn> {
if zelf.iterators.is_empty() {
@@ -1851,7 +1932,9 @@ mod decl {
#[pyclass(with(IterNext, Iterable, Constructor))]
impl PyItertoolsPairwise {}
impl SelfIter for PyItertoolsPairwise {}
impl IterNext for PyItertoolsPairwise {
fn next(zelf: &Py<Self>, vm: &VirtualMachine) -> PyResult<PyIterReturn> {
let old = match zelf.old.read().clone() {