Refactor positional iterator with general logic

This commit is contained in:
Kangzhi Shi
2021-09-19 20:49:13 +02:00
parent 4e6c451b2a
commit bf04b505b1
4 changed files with 178 additions and 150 deletions

View File

@@ -717,8 +717,7 @@ impl Unhashable for PyByteArray {}
impl Iterable for PyByteArray {
fn iter(zelf: PyRef<Self>, vm: &VirtualMachine) -> PyResult {
Ok(PyByteArrayIterator {
position: AtomicCell::new(0),
bytearray: zelf,
internal: PositionIterInternal::new(zelf.into_object()),
}
.into_object(vm))
}
@@ -731,8 +730,7 @@ impl Iterable for PyByteArray {
#[pyclass(module = false, name = "bytearray_iterator")]
#[derive(Debug)]
pub struct PyByteArrayIterator {
position: AtomicCell<usize>,
bytearray: PyByteArrayRef,
internal: PositionIterInternal,
}
impl PyValue for PyByteArrayIterator {
@@ -743,36 +741,45 @@ impl PyValue for PyByteArrayIterator {
#[pyimpl(with(SlotIterator))]
impl PyByteArrayIterator {
#[pymethod(magic)]
fn length_hint(&self, vm: &VirtualMachine) -> PyResult {
self.internal.length_hint(
|| {
Ok(self
.internal
.obj
.read()
.payload::<PyByteArray>()
.unwrap()
.len())
},
vm,
)
}
#[pymethod(magic)]
fn reduce(&self, vm: &VirtualMachine) -> PyResult {
let iter = vm.get_attribute(vm.builtins.clone(), "iter")?;
Ok(vm.ctx.new_tuple(vec![
iter,
vm.ctx.new_tuple(vec![self.bytearray.clone().into_object()]),
vm.ctx.new_int(self.position.load()),
]))
Ok(self.internal.reduce(iter, vm))
}
#[pymethod(magic)]
fn setstate(&self, state: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> {
if let Some(i) = state.payload::<PyInt>() {
self.position
.store(int::try_to_primitive(i.as_bigint(), vm).unwrap_or(0));
Ok(())
} else {
Err(vm.new_type_error("an integer is required.".to_owned()))
}
self.internal.set_state(state, vm)
}
}
impl IteratorIterable for PyByteArrayIterator {}
impl SlotIterator for PyByteArrayIterator {
fn next(zelf: &PyRef<Self>, vm: &VirtualMachine) -> PyResult<PyIterReturn> {
let pos = zelf.position.fetch_add(1);
let r = if let Some(&ret) = zelf.bytearray.borrow_buf().get(pos) {
PyIterReturn::Return(ret.into_pyobject(vm))
} else {
PyIterReturn::StopIteration(None)
};
Ok(r)
fn next(zelf: &PyRef<Self>, vm: &VirtualMachine) -> PyResult {
zelf.internal.next(
|pos| {
let bytearray = zelf.internal.obj.read();
let bytearray = bytearray.payload::<PyByteArray>().unwrap();
let buf = bytearray.borrow_buf();
buf.get(pos)
.ok_or_else(|| vm.new_stop_iteration())
.map(|&x| vm.ctx.new_int(x))
},
vm,
)
}
}

View File

@@ -17,7 +17,6 @@ use crate::{
PyRef, PyResult, PyValue, TryFromBorrowedObject, TypeProtocol, VirtualMachine,
};
use bstr::ByteSlice;
use crossbeam_utils::atomic::AtomicCell;
use rustpython_common::borrow::{BorrowedValue, BorrowedValueMut};
use std::mem::size_of;
use std::ops::Deref;
@@ -574,8 +573,7 @@ impl Comparable for PyBytes {
impl Iterable for PyBytes {
fn iter(zelf: PyRef<Self>, vm: &VirtualMachine) -> PyResult {
Ok(PyBytesIterator {
position: AtomicCell::new(0),
bytes: zelf,
internal: PositionIterInternal::new(zelf.into_object()),
}
.into_object(vm))
}
@@ -584,8 +582,7 @@ impl Iterable for PyBytes {
#[pyclass(module = false, name = "bytes_iterator")]
#[derive(Debug)]
pub struct PyBytesIterator {
position: AtomicCell<usize>,
bytes: PyBytesRef,
internal: PositionIterInternal,
}
impl PyValue for PyBytesIterator {
@@ -596,37 +593,40 @@ impl PyValue for PyBytesIterator {
#[pyimpl(with(SlotIterator))]
impl PyBytesIterator {
#[pymethod(magic)]
fn length_hint(&self, vm: &VirtualMachine) -> PyResult {
self.internal.length_hint(
|| Ok(self.internal.obj.read().payload::<PyBytes>().unwrap().len()),
vm,
)
}
#[pymethod(magic)]
fn reduce(&self, vm: &VirtualMachine) -> PyResult {
let iter = vm.get_attribute(vm.builtins.clone(), "iter")?;
Ok(vm.ctx.new_tuple(vec![
iter,
vm.ctx.new_tuple(vec![self.bytes.clone().into_object()]),
vm.ctx.new_int(self.position.load()),
]))
Ok(self.internal.reduce(iter, vm))
}
#[pymethod(magic)]
fn setstate(&self, state: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> {
if let Some(i) = state.payload::<PyInt>() {
self.position
.store(int::try_to_primitive(i.as_bigint(), vm).unwrap_or(0));
Ok(())
} else {
Err(vm.new_type_error("an integer is required.".to_owned()))
}
self.internal.set_state(state, vm)
}
}
impl IteratorIterable for PyBytesIterator {}
impl SlotIterator for PyBytesIterator {
fn next(zelf: &PyRef<Self>, vm: &VirtualMachine) -> PyResult<PyIterReturn> {
let pos = zelf.position.fetch_add(1);
let r = if let Some(&ret) = zelf.bytes.as_bytes().get(pos) {
PyIterReturn::Return(vm.ctx.new_int(ret))
} else {
PyIterReturn::StopIteration(None)
};
Ok(r)
fn next(zelf: &PyRef<Self>, vm: &VirtualMachine) -> PyResult {
zelf.internal.next(
|pos| {
let bytes = zelf.internal.obj.read();
let bytes = bytes.payload::<PyBytes>().unwrap();
bytes
.as_bytes()
.get(pos)
.ok_or_else(|| vm.new_stop_iteration())
.map(|&x| vm.ctx.new_int(x))
},
vm,
)
}
}

View File

@@ -12,6 +12,88 @@ use crate::{
};
use crossbeam_utils::atomic::AtomicCell;
#[derive(Debug)]
pub struct PositionIterInternal {
pub position: AtomicCell<usize>,
/// object or PyNone if exhausted
pub obj: PyRwLock<PyObjectRef>,
}
impl PositionIterInternal {
pub fn new(obj: PyObjectRef) -> Self {
Self {
position: AtomicCell::new(0),
obj: PyRwLock::new(obj),
}
}
pub fn is_active(&self, vm: &VirtualMachine) -> bool {
!vm.is_none(&self.obj.read())
}
pub fn set_state(&self, state: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> {
if self.is_active(vm) {
if let Some(i) = state.payload::<PyInt>() {
let i = int::try_to_primitive(i.as_bigint(), vm).unwrap_or(0);
self.position.store(i);
Ok(())
} else {
Err(vm.new_type_error("an integer is required.".to_owned()))
}
} else {
Ok(())
}
}
pub fn reduce(&self, func: PyObjectRef, vm: &VirtualMachine) -> PyObjectRef {
if self.is_active(vm) {
vm.ctx.new_tuple(vec![
func,
vm.ctx.new_tuple(vec![self.obj.read().clone()]),
vm.ctx.new_int(self.position.load()),
])
} else {
vm.ctx
.new_tuple(vec![func, vm.ctx.new_tuple(vec![vm.ctx.new_list(vec![])])])
}
}
pub fn next<F>(&self, f: F, vm: &VirtualMachine) -> PyResult
where
F: FnOnce(usize) -> PyResult,
{
if self.is_active(vm) {
let pos = self.position.fetch_add(1);
match f(pos) {
Err(ref e)
if e.isinstance(&vm.ctx.exceptions.index_error)
|| e.isinstance(&vm.ctx.exceptions.stop_iteration) =>
{
*self.obj.write() = vm.ctx.none();
Err(vm.new_stop_iteration())
}
ret => ret,
}
} else {
Err(vm.new_stop_iteration())
}
}
pub fn length_hint<F>(&self, f: F, vm: &VirtualMachine) -> PyResult
where
F: FnOnce() -> PyResult<usize>,
{
let len = if self.is_active(vm) {
let pos = self.position.load();
let obj_len = f()?;
obj_len.saturating_sub(pos)
} else {
0
};
Ok(PyInt::from(len).into_object(vm))
}
}
/// Marks status of iterator.
#[derive(Debug, Clone, Copy)]
pub enum IterStatus {
@@ -24,9 +106,10 @@ pub enum IterStatus {
#[pyclass(module = false, name = "iterator")]
#[derive(Debug)]
pub struct PySequenceIterator {
pub position: AtomicCell<usize>,
pub obj: PyObjectRef,
pub status: AtomicCell<IterStatus>,
internal: PositionIterInternal,
// pub position: AtomicCell<usize>,
// pub obj: PyObjectRef,
// pub status: AtomicCell<IterStatus>
}
impl PyValue for PySequenceIterator {
@@ -39,75 +122,38 @@ impl PyValue for PySequenceIterator {
impl PySequenceIterator {
pub fn new(obj: PyObjectRef) -> Self {
Self {
position: AtomicCell::new(0),
obj,
status: AtomicCell::new(IterStatus::Active),
internal: PositionIterInternal::new(obj),
}
}
#[pymethod(magic)]
fn length_hint(&self, vm: &VirtualMachine) -> PyObjectRef {
match self.status.load() {
IterStatus::Active => {
let pos = self.position.load();
// return NotImplemented if no length is around.
vm.obj_len(&self.obj)
.map_or(vm.ctx.not_implemented(), |len| {
PyInt::from(len.saturating_sub(pos)).into_object(vm)
})
}
IterStatus::Exhausted => PyInt::from(0).into_object(vm),
}
fn length_hint(&self, vm: &VirtualMachine) -> PyResult {
self.internal.length_hint(
|| {
vm.obj_len(&self.internal.obj.read())
.map_err(|_| vm.new_not_implemented_error("".to_owned()))
},
vm,
)
}
#[pymethod(magic)]
fn reduce(&self, vm: &VirtualMachine) -> PyResult {
let iter = vm.get_attribute(vm.builtins.clone(), "iter")?;
Ok(match self.status.load() {
IterStatus::Exhausted => vm
.ctx
.new_tuple(vec![iter, vm.ctx.new_tuple(vec![vm.ctx.new_list(vec![])])]),
IterStatus::Active => vm.ctx.new_tuple(vec![
iter,
vm.ctx.new_tuple(vec![self.obj.clone()]),
vm.ctx.new_int(self.position.load()),
]),
})
Ok(self.internal.reduce(iter, vm))
}
#[pymethod(magic)]
fn setstate(&self, state: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> {
// When we're exhausted, just return.
if let IterStatus::Exhausted = self.status.load() {
return Ok(());
}
if let Some(i) = state.payload::<PyInt>() {
self.position
.store(int::try_to_primitive(i.as_bigint(), vm).unwrap_or(0));
Ok(())
} else {
Err(vm.new_type_error("an integer is required.".to_owned()))
}
self.internal.set_state(state, vm)
}
}
impl IteratorIterable for PySequenceIterator {}
impl SlotIterator for PySequenceIterator {
fn next(zelf: &PyRef<Self>, vm: &VirtualMachine) -> PyResult<PyIterReturn> {
if let IterStatus::Exhausted = zelf.status.load() {
return Ok(PyIterReturn::StopIteration(None));
}
let pos = zelf.position.fetch_add(1);
match zelf.obj.get_item(pos, vm) {
Err(ref e)
if e.isinstance(&vm.ctx.exceptions.index_error)
|| e.isinstance(&vm.ctx.exceptions.stop_iteration) =>
{
zelf.status.store(IterStatus::Exhausted);
Ok(PyIterReturn::StopIteration(None))
}
ret => ret.map(PyIterReturn::Return),
}
fn next(zelf: &PyRef<Self>, vm: &VirtualMachine) -> PyResult {
zelf.internal
.next(|pos| zelf.internal.obj.read().get_item(pos, vm), vm)
}
}

View File

@@ -310,9 +310,7 @@ impl Comparable for PyTuple {
impl Iterable for PyTuple {
fn iter(zelf: PyRef<Self>, vm: &VirtualMachine) -> PyResult {
Ok(PyTupleIterator {
position: AtomicCell::new(0),
status: AtomicCell::new(Active),
tuple: zelf,
internal: PositionIterInternal::new(zelf.into_object()),
}
.into_object(vm))
}
@@ -321,9 +319,7 @@ impl Iterable for PyTuple {
#[pyclass(module = false, name = "tuple_iterator")]
#[derive(Debug)]
pub(crate) struct PyTupleIterator {
position: AtomicCell<usize>,
status: AtomicCell<IterStatus>,
tuple: PyTupleRef,
internal: PositionIterInternal,
}
impl PyValue for PyTupleIterator {
@@ -335,61 +331,40 @@ impl PyValue for PyTupleIterator {
#[pyimpl(with(SlotIterator))]
impl PyTupleIterator {
#[pymethod(magic)]
fn length_hint(&self) -> usize {
match self.status.load() {
Active => self.tuple.len().saturating_sub(self.position.load()),
Exhausted => 0,
}
fn length_hint(&self, vm: &VirtualMachine) -> PyResult {
self.internal.length_hint(
|| Ok(self.internal.obj.read().payload::<PyTuple>().unwrap().len()),
vm,
)
}
#[pymethod(magic)]
fn setstate(&self, state: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> {
// When we're exhausted, just return.
if let Exhausted = self.status.load() {
return Ok(());
}
// Else, set to min of (pos, tuple_size).
if let Some(i) = state.payload::<PyInt>() {
let position = std::cmp::min(
int::try_to_primitive(i.as_bigint(), vm).unwrap_or(0),
self.tuple.len(),
);
self.position.store(position);
Ok(())
} else {
Err(vm.new_type_error("an integer is required.".to_owned()))
}
self.internal.set_state(state, vm)
}
#[pymethod(magic)]
fn reduce(&self, vm: &VirtualMachine) -> PyResult {
let iter = vm.get_attribute(vm.builtins.clone(), "iter")?;
Ok(match self.status.load() {
Exhausted => vm
.ctx
.new_tuple(vec![iter, vm.ctx.new_tuple(vec![vm.ctx.new_list(vec![])])]),
Active => vm.ctx.new_tuple(vec![
iter,
vm.ctx.new_tuple(vec![self.tuple.clone().into_object()]),
vm.ctx.new_int(self.position.load()),
]),
})
Ok(self.internal.reduce(iter, vm))
}
}
impl IteratorIterable for PyTupleIterator {}
impl SlotIterator for PyTupleIterator {
fn next(zelf: &PyRef<Self>, _vm: &VirtualMachine) -> PyResult<PyIterReturn> {
if let Exhausted = zelf.status.load() {
return Ok(PyIterReturn::StopIteration(None));
}
let pos = zelf.position.fetch_add(1);
if let Some(obj) = zelf.tuple.as_slice().get(pos) {
Ok(PyIterReturn::Return(obj.clone()))
} else {
zelf.status.store(Exhausted);
Ok(PyIterReturn::StopIteration(None))
}
fn next(zelf: &PyRef<Self>, vm: &VirtualMachine) -> PyResult {
zelf.internal.next(
|pos| {
let tuple = zelf.internal.obj.read();
let tuple = tuple.payload::<PyTuple>().unwrap();
tuple
.as_slice()
.get(pos)
.ok_or_else(|| vm.new_stop_iteration())
.map(|x| x.clone())
},
vm,
)
}
}