From bf04b505b1a4e43961581f7a6b9be626122151b4 Mon Sep 17 00:00:00 2001 From: Kangzhi Shi Date: Sun, 19 Sep 2021 20:49:13 +0200 Subject: [PATCH] Refactor positional iterator with general logic --- vm/src/builtins/bytearray.rs | 55 +++++++------ vm/src/builtins/bytes.rs | 50 ++++++------ vm/src/builtins/iter.rs | 154 +++++++++++++++++++++++------------ vm/src/builtins/tuple.rs | 69 +++++----------- 4 files changed, 178 insertions(+), 150 deletions(-) diff --git a/vm/src/builtins/bytearray.rs b/vm/src/builtins/bytearray.rs index 694cdadf9..77c664946 100644 --- a/vm/src/builtins/bytearray.rs +++ b/vm/src/builtins/bytearray.rs @@ -717,8 +717,7 @@ impl Unhashable for PyByteArray {} impl Iterable for PyByteArray { fn iter(zelf: PyRef, vm: &VirtualMachine) -> PyResult { Ok(PyByteArrayIterator { - position: AtomicCell::new(0), - bytearray: zelf, + internal: PositionIterInternal::new(zelf.into_object()), } .into_object(vm)) } @@ -731,8 +730,7 @@ impl Iterable for PyByteArray { #[pyclass(module = false, name = "bytearray_iterator")] #[derive(Debug)] pub struct PyByteArrayIterator { - position: AtomicCell, - bytearray: PyByteArrayRef, + internal: PositionIterInternal, } impl PyValue for PyByteArrayIterator { @@ -743,36 +741,45 @@ impl PyValue for PyByteArrayIterator { #[pyimpl(with(SlotIterator))] impl PyByteArrayIterator { + #[pymethod(magic)] + fn length_hint(&self, vm: &VirtualMachine) -> PyResult { + self.internal.length_hint( + || { + Ok(self + .internal + .obj + .read() + .payload::() + .unwrap() + .len()) + }, + vm, + ) + } #[pymethod(magic)] fn reduce(&self, vm: &VirtualMachine) -> PyResult { let iter = vm.get_attribute(vm.builtins.clone(), "iter")?; - Ok(vm.ctx.new_tuple(vec![ - iter, - vm.ctx.new_tuple(vec![self.bytearray.clone().into_object()]), - vm.ctx.new_int(self.position.load()), - ])) + Ok(self.internal.reduce(iter, vm)) } #[pymethod(magic)] fn setstate(&self, state: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> { - if let Some(i) = state.payload::() { - self.position - .store(int::try_to_primitive(i.as_bigint(), vm).unwrap_or(0)); - Ok(()) - } else { - Err(vm.new_type_error("an integer is required.".to_owned())) - } + self.internal.set_state(state, vm) } } impl IteratorIterable for PyByteArrayIterator {} impl SlotIterator for PyByteArrayIterator { - fn next(zelf: &PyRef, vm: &VirtualMachine) -> PyResult { - let pos = zelf.position.fetch_add(1); - let r = if let Some(&ret) = zelf.bytearray.borrow_buf().get(pos) { - PyIterReturn::Return(ret.into_pyobject(vm)) - } else { - PyIterReturn::StopIteration(None) - }; - Ok(r) + fn next(zelf: &PyRef, vm: &VirtualMachine) -> PyResult { + zelf.internal.next( + |pos| { + let bytearray = zelf.internal.obj.read(); + let bytearray = bytearray.payload::().unwrap(); + let buf = bytearray.borrow_buf(); + buf.get(pos) + .ok_or_else(|| vm.new_stop_iteration()) + .map(|&x| vm.ctx.new_int(x)) + }, + vm, + ) } } diff --git a/vm/src/builtins/bytes.rs b/vm/src/builtins/bytes.rs index dfc7e76dc..a94da0a17 100644 --- a/vm/src/builtins/bytes.rs +++ b/vm/src/builtins/bytes.rs @@ -17,7 +17,6 @@ use crate::{ PyRef, PyResult, PyValue, TryFromBorrowedObject, TypeProtocol, VirtualMachine, }; use bstr::ByteSlice; -use crossbeam_utils::atomic::AtomicCell; use rustpython_common::borrow::{BorrowedValue, BorrowedValueMut}; use std::mem::size_of; use std::ops::Deref; @@ -574,8 +573,7 @@ impl Comparable for PyBytes { impl Iterable for PyBytes { fn iter(zelf: PyRef, vm: &VirtualMachine) -> PyResult { Ok(PyBytesIterator { - position: AtomicCell::new(0), - bytes: zelf, + internal: PositionIterInternal::new(zelf.into_object()), } .into_object(vm)) } @@ -584,8 +582,7 @@ impl Iterable for PyBytes { #[pyclass(module = false, name = "bytes_iterator")] #[derive(Debug)] pub struct PyBytesIterator { - position: AtomicCell, - bytes: PyBytesRef, + internal: PositionIterInternal, } impl PyValue for PyBytesIterator { @@ -596,37 +593,40 @@ impl PyValue for PyBytesIterator { #[pyimpl(with(SlotIterator))] impl PyBytesIterator { + #[pymethod(magic)] + fn length_hint(&self, vm: &VirtualMachine) -> PyResult { + self.internal.length_hint( + || Ok(self.internal.obj.read().payload::().unwrap().len()), + vm, + ) + } + #[pymethod(magic)] fn reduce(&self, vm: &VirtualMachine) -> PyResult { let iter = vm.get_attribute(vm.builtins.clone(), "iter")?; - Ok(vm.ctx.new_tuple(vec![ - iter, - vm.ctx.new_tuple(vec![self.bytes.clone().into_object()]), - vm.ctx.new_int(self.position.load()), - ])) + Ok(self.internal.reduce(iter, vm)) } #[pymethod(magic)] fn setstate(&self, state: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> { - if let Some(i) = state.payload::() { - self.position - .store(int::try_to_primitive(i.as_bigint(), vm).unwrap_or(0)); - Ok(()) - } else { - Err(vm.new_type_error("an integer is required.".to_owned())) - } + self.internal.set_state(state, vm) } } impl IteratorIterable for PyBytesIterator {} impl SlotIterator for PyBytesIterator { - fn next(zelf: &PyRef, vm: &VirtualMachine) -> PyResult { - let pos = zelf.position.fetch_add(1); - let r = if let Some(&ret) = zelf.bytes.as_bytes().get(pos) { - PyIterReturn::Return(vm.ctx.new_int(ret)) - } else { - PyIterReturn::StopIteration(None) - }; - Ok(r) + fn next(zelf: &PyRef, vm: &VirtualMachine) -> PyResult { + zelf.internal.next( + |pos| { + let bytes = zelf.internal.obj.read(); + let bytes = bytes.payload::().unwrap(); + bytes + .as_bytes() + .get(pos) + .ok_or_else(|| vm.new_stop_iteration()) + .map(|&x| vm.ctx.new_int(x)) + }, + vm, + ) } } diff --git a/vm/src/builtins/iter.rs b/vm/src/builtins/iter.rs index efb0756b2..3a578106f 100644 --- a/vm/src/builtins/iter.rs +++ b/vm/src/builtins/iter.rs @@ -12,6 +12,88 @@ use crate::{ }; use crossbeam_utils::atomic::AtomicCell; +#[derive(Debug)] +pub struct PositionIterInternal { + pub position: AtomicCell, + /// object or PyNone if exhausted + pub obj: PyRwLock, +} + +impl PositionIterInternal { + pub fn new(obj: PyObjectRef) -> Self { + Self { + position: AtomicCell::new(0), + obj: PyRwLock::new(obj), + } + } + + pub fn is_active(&self, vm: &VirtualMachine) -> bool { + !vm.is_none(&self.obj.read()) + } + + pub fn set_state(&self, state: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> { + if self.is_active(vm) { + if let Some(i) = state.payload::() { + let i = int::try_to_primitive(i.as_bigint(), vm).unwrap_or(0); + self.position.store(i); + Ok(()) + } else { + Err(vm.new_type_error("an integer is required.".to_owned())) + } + } else { + Ok(()) + } + } + + pub fn reduce(&self, func: PyObjectRef, vm: &VirtualMachine) -> PyObjectRef { + if self.is_active(vm) { + vm.ctx.new_tuple(vec![ + func, + vm.ctx.new_tuple(vec![self.obj.read().clone()]), + vm.ctx.new_int(self.position.load()), + ]) + } else { + vm.ctx + .new_tuple(vec![func, vm.ctx.new_tuple(vec![vm.ctx.new_list(vec![])])]) + } + } + + pub fn next(&self, f: F, vm: &VirtualMachine) -> PyResult + where + F: FnOnce(usize) -> PyResult, + { + if self.is_active(vm) { + let pos = self.position.fetch_add(1); + match f(pos) { + Err(ref e) + if e.isinstance(&vm.ctx.exceptions.index_error) + || e.isinstance(&vm.ctx.exceptions.stop_iteration) => + { + *self.obj.write() = vm.ctx.none(); + Err(vm.new_stop_iteration()) + } + ret => ret, + } + } else { + Err(vm.new_stop_iteration()) + } + } + + pub fn length_hint(&self, f: F, vm: &VirtualMachine) -> PyResult + where + F: FnOnce() -> PyResult, + { + let len = if self.is_active(vm) { + let pos = self.position.load(); + let obj_len = f()?; + obj_len.saturating_sub(pos) + } else { + 0 + }; + Ok(PyInt::from(len).into_object(vm)) + } +} + /// Marks status of iterator. #[derive(Debug, Clone, Copy)] pub enum IterStatus { @@ -24,9 +106,10 @@ pub enum IterStatus { #[pyclass(module = false, name = "iterator")] #[derive(Debug)] pub struct PySequenceIterator { - pub position: AtomicCell, - pub obj: PyObjectRef, - pub status: AtomicCell, + internal: PositionIterInternal, + // pub position: AtomicCell, + // pub obj: PyObjectRef, + // pub status: AtomicCell } impl PyValue for PySequenceIterator { @@ -39,75 +122,38 @@ impl PyValue for PySequenceIterator { impl PySequenceIterator { pub fn new(obj: PyObjectRef) -> Self { Self { - position: AtomicCell::new(0), - obj, - status: AtomicCell::new(IterStatus::Active), + internal: PositionIterInternal::new(obj), } } #[pymethod(magic)] - fn length_hint(&self, vm: &VirtualMachine) -> PyObjectRef { - match self.status.load() { - IterStatus::Active => { - let pos = self.position.load(); - // return NotImplemented if no length is around. - vm.obj_len(&self.obj) - .map_or(vm.ctx.not_implemented(), |len| { - PyInt::from(len.saturating_sub(pos)).into_object(vm) - }) - } - IterStatus::Exhausted => PyInt::from(0).into_object(vm), - } + fn length_hint(&self, vm: &VirtualMachine) -> PyResult { + self.internal.length_hint( + || { + vm.obj_len(&self.internal.obj.read()) + .map_err(|_| vm.new_not_implemented_error("".to_owned())) + }, + vm, + ) } #[pymethod(magic)] fn reduce(&self, vm: &VirtualMachine) -> PyResult { let iter = vm.get_attribute(vm.builtins.clone(), "iter")?; - Ok(match self.status.load() { - IterStatus::Exhausted => vm - .ctx - .new_tuple(vec![iter, vm.ctx.new_tuple(vec![vm.ctx.new_list(vec![])])]), - IterStatus::Active => vm.ctx.new_tuple(vec![ - iter, - vm.ctx.new_tuple(vec![self.obj.clone()]), - vm.ctx.new_int(self.position.load()), - ]), - }) + Ok(self.internal.reduce(iter, vm)) } #[pymethod(magic)] fn setstate(&self, state: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> { - // When we're exhausted, just return. - if let IterStatus::Exhausted = self.status.load() { - return Ok(()); - } - if let Some(i) = state.payload::() { - self.position - .store(int::try_to_primitive(i.as_bigint(), vm).unwrap_or(0)); - Ok(()) - } else { - Err(vm.new_type_error("an integer is required.".to_owned())) - } + self.internal.set_state(state, vm) } } impl IteratorIterable for PySequenceIterator {} impl SlotIterator for PySequenceIterator { - fn next(zelf: &PyRef, vm: &VirtualMachine) -> PyResult { - if let IterStatus::Exhausted = zelf.status.load() { - return Ok(PyIterReturn::StopIteration(None)); - } - let pos = zelf.position.fetch_add(1); - match zelf.obj.get_item(pos, vm) { - Err(ref e) - if e.isinstance(&vm.ctx.exceptions.index_error) - || e.isinstance(&vm.ctx.exceptions.stop_iteration) => - { - zelf.status.store(IterStatus::Exhausted); - Ok(PyIterReturn::StopIteration(None)) - } - ret => ret.map(PyIterReturn::Return), - } + fn next(zelf: &PyRef, vm: &VirtualMachine) -> PyResult { + zelf.internal + .next(|pos| zelf.internal.obj.read().get_item(pos, vm), vm) } } diff --git a/vm/src/builtins/tuple.rs b/vm/src/builtins/tuple.rs index f733f222a..160f01d47 100644 --- a/vm/src/builtins/tuple.rs +++ b/vm/src/builtins/tuple.rs @@ -310,9 +310,7 @@ impl Comparable for PyTuple { impl Iterable for PyTuple { fn iter(zelf: PyRef, vm: &VirtualMachine) -> PyResult { Ok(PyTupleIterator { - position: AtomicCell::new(0), - status: AtomicCell::new(Active), - tuple: zelf, + internal: PositionIterInternal::new(zelf.into_object()), } .into_object(vm)) } @@ -321,9 +319,7 @@ impl Iterable for PyTuple { #[pyclass(module = false, name = "tuple_iterator")] #[derive(Debug)] pub(crate) struct PyTupleIterator { - position: AtomicCell, - status: AtomicCell, - tuple: PyTupleRef, + internal: PositionIterInternal, } impl PyValue for PyTupleIterator { @@ -335,61 +331,40 @@ impl PyValue for PyTupleIterator { #[pyimpl(with(SlotIterator))] impl PyTupleIterator { #[pymethod(magic)] - fn length_hint(&self) -> usize { - match self.status.load() { - Active => self.tuple.len().saturating_sub(self.position.load()), - Exhausted => 0, - } + fn length_hint(&self, vm: &VirtualMachine) -> PyResult { + self.internal.length_hint( + || Ok(self.internal.obj.read().payload::().unwrap().len()), + vm, + ) } #[pymethod(magic)] fn setstate(&self, state: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> { - // When we're exhausted, just return. - if let Exhausted = self.status.load() { - return Ok(()); - } - // Else, set to min of (pos, tuple_size). - if let Some(i) = state.payload::() { - let position = std::cmp::min( - int::try_to_primitive(i.as_bigint(), vm).unwrap_or(0), - self.tuple.len(), - ); - self.position.store(position); - Ok(()) - } else { - Err(vm.new_type_error("an integer is required.".to_owned())) - } + self.internal.set_state(state, vm) } #[pymethod(magic)] fn reduce(&self, vm: &VirtualMachine) -> PyResult { let iter = vm.get_attribute(vm.builtins.clone(), "iter")?; - Ok(match self.status.load() { - Exhausted => vm - .ctx - .new_tuple(vec![iter, vm.ctx.new_tuple(vec![vm.ctx.new_list(vec![])])]), - Active => vm.ctx.new_tuple(vec![ - iter, - vm.ctx.new_tuple(vec![self.tuple.clone().into_object()]), - vm.ctx.new_int(self.position.load()), - ]), - }) + Ok(self.internal.reduce(iter, vm)) } } impl IteratorIterable for PyTupleIterator {} impl SlotIterator for PyTupleIterator { - fn next(zelf: &PyRef, _vm: &VirtualMachine) -> PyResult { - if let Exhausted = zelf.status.load() { - return Ok(PyIterReturn::StopIteration(None)); - } - let pos = zelf.position.fetch_add(1); - if let Some(obj) = zelf.tuple.as_slice().get(pos) { - Ok(PyIterReturn::Return(obj.clone())) - } else { - zelf.status.store(Exhausted); - Ok(PyIterReturn::StopIteration(None)) - } + fn next(zelf: &PyRef, vm: &VirtualMachine) -> PyResult { + zelf.internal.next( + |pos| { + let tuple = zelf.internal.obj.read(); + let tuple = tuple.payload::().unwrap(); + tuple + .as_slice() + .get(pos) + .ok_or_else(|| vm.new_stop_iteration()) + .map(|x| x.clone()) + }, + vm, + ) } }