mirror of
https://github.com/RustPython/RustPython.git
synced 2026-06-09 22:49:57 +09:00
Refactor positional iterator with general logic
This commit is contained in:
@@ -717,8 +717,7 @@ impl Unhashable for PyByteArray {}
|
||||
impl Iterable for PyByteArray {
|
||||
fn iter(zelf: PyRef<Self>, vm: &VirtualMachine) -> PyResult {
|
||||
Ok(PyByteArrayIterator {
|
||||
position: AtomicCell::new(0),
|
||||
bytearray: zelf,
|
||||
internal: PositionIterInternal::new(zelf.into_object()),
|
||||
}
|
||||
.into_object(vm))
|
||||
}
|
||||
@@ -731,8 +730,7 @@ impl Iterable for PyByteArray {
|
||||
#[pyclass(module = false, name = "bytearray_iterator")]
|
||||
#[derive(Debug)]
|
||||
pub struct PyByteArrayIterator {
|
||||
position: AtomicCell<usize>,
|
||||
bytearray: PyByteArrayRef,
|
||||
internal: PositionIterInternal,
|
||||
}
|
||||
|
||||
impl PyValue for PyByteArrayIterator {
|
||||
@@ -743,36 +741,45 @@ impl PyValue for PyByteArrayIterator {
|
||||
|
||||
#[pyimpl(with(SlotIterator))]
|
||||
impl PyByteArrayIterator {
|
||||
#[pymethod(magic)]
|
||||
fn length_hint(&self, vm: &VirtualMachine) -> PyResult {
|
||||
self.internal.length_hint(
|
||||
|| {
|
||||
Ok(self
|
||||
.internal
|
||||
.obj
|
||||
.read()
|
||||
.payload::<PyByteArray>()
|
||||
.unwrap()
|
||||
.len())
|
||||
},
|
||||
vm,
|
||||
)
|
||||
}
|
||||
#[pymethod(magic)]
|
||||
fn reduce(&self, vm: &VirtualMachine) -> PyResult {
|
||||
let iter = vm.get_attribute(vm.builtins.clone(), "iter")?;
|
||||
Ok(vm.ctx.new_tuple(vec![
|
||||
iter,
|
||||
vm.ctx.new_tuple(vec![self.bytearray.clone().into_object()]),
|
||||
vm.ctx.new_int(self.position.load()),
|
||||
]))
|
||||
Ok(self.internal.reduce(iter, vm))
|
||||
}
|
||||
|
||||
#[pymethod(magic)]
|
||||
fn setstate(&self, state: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> {
|
||||
if let Some(i) = state.payload::<PyInt>() {
|
||||
self.position
|
||||
.store(int::try_to_primitive(i.as_bigint(), vm).unwrap_or(0));
|
||||
Ok(())
|
||||
} else {
|
||||
Err(vm.new_type_error("an integer is required.".to_owned()))
|
||||
}
|
||||
self.internal.set_state(state, vm)
|
||||
}
|
||||
}
|
||||
impl IteratorIterable for PyByteArrayIterator {}
|
||||
impl SlotIterator for PyByteArrayIterator {
|
||||
fn next(zelf: &PyRef<Self>, vm: &VirtualMachine) -> PyResult<PyIterReturn> {
|
||||
let pos = zelf.position.fetch_add(1);
|
||||
let r = if let Some(&ret) = zelf.bytearray.borrow_buf().get(pos) {
|
||||
PyIterReturn::Return(ret.into_pyobject(vm))
|
||||
} else {
|
||||
PyIterReturn::StopIteration(None)
|
||||
};
|
||||
Ok(r)
|
||||
fn next(zelf: &PyRef<Self>, vm: &VirtualMachine) -> PyResult {
|
||||
zelf.internal.next(
|
||||
|pos| {
|
||||
let bytearray = zelf.internal.obj.read();
|
||||
let bytearray = bytearray.payload::<PyByteArray>().unwrap();
|
||||
let buf = bytearray.borrow_buf();
|
||||
buf.get(pos)
|
||||
.ok_or_else(|| vm.new_stop_iteration())
|
||||
.map(|&x| vm.ctx.new_int(x))
|
||||
},
|
||||
vm,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -17,7 +17,6 @@ use crate::{
|
||||
PyRef, PyResult, PyValue, TryFromBorrowedObject, TypeProtocol, VirtualMachine,
|
||||
};
|
||||
use bstr::ByteSlice;
|
||||
use crossbeam_utils::atomic::AtomicCell;
|
||||
use rustpython_common::borrow::{BorrowedValue, BorrowedValueMut};
|
||||
use std::mem::size_of;
|
||||
use std::ops::Deref;
|
||||
@@ -574,8 +573,7 @@ impl Comparable for PyBytes {
|
||||
impl Iterable for PyBytes {
|
||||
fn iter(zelf: PyRef<Self>, vm: &VirtualMachine) -> PyResult {
|
||||
Ok(PyBytesIterator {
|
||||
position: AtomicCell::new(0),
|
||||
bytes: zelf,
|
||||
internal: PositionIterInternal::new(zelf.into_object()),
|
||||
}
|
||||
.into_object(vm))
|
||||
}
|
||||
@@ -584,8 +582,7 @@ impl Iterable for PyBytes {
|
||||
#[pyclass(module = false, name = "bytes_iterator")]
|
||||
#[derive(Debug)]
|
||||
pub struct PyBytesIterator {
|
||||
position: AtomicCell<usize>,
|
||||
bytes: PyBytesRef,
|
||||
internal: PositionIterInternal,
|
||||
}
|
||||
|
||||
impl PyValue for PyBytesIterator {
|
||||
@@ -596,37 +593,40 @@ impl PyValue for PyBytesIterator {
|
||||
|
||||
#[pyimpl(with(SlotIterator))]
|
||||
impl PyBytesIterator {
|
||||
#[pymethod(magic)]
|
||||
fn length_hint(&self, vm: &VirtualMachine) -> PyResult {
|
||||
self.internal.length_hint(
|
||||
|| Ok(self.internal.obj.read().payload::<PyBytes>().unwrap().len()),
|
||||
vm,
|
||||
)
|
||||
}
|
||||
|
||||
#[pymethod(magic)]
|
||||
fn reduce(&self, vm: &VirtualMachine) -> PyResult {
|
||||
let iter = vm.get_attribute(vm.builtins.clone(), "iter")?;
|
||||
Ok(vm.ctx.new_tuple(vec![
|
||||
iter,
|
||||
vm.ctx.new_tuple(vec![self.bytes.clone().into_object()]),
|
||||
vm.ctx.new_int(self.position.load()),
|
||||
]))
|
||||
Ok(self.internal.reduce(iter, vm))
|
||||
}
|
||||
|
||||
#[pymethod(magic)]
|
||||
fn setstate(&self, state: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> {
|
||||
if let Some(i) = state.payload::<PyInt>() {
|
||||
self.position
|
||||
.store(int::try_to_primitive(i.as_bigint(), vm).unwrap_or(0));
|
||||
Ok(())
|
||||
} else {
|
||||
Err(vm.new_type_error("an integer is required.".to_owned()))
|
||||
}
|
||||
self.internal.set_state(state, vm)
|
||||
}
|
||||
}
|
||||
impl IteratorIterable for PyBytesIterator {}
|
||||
impl SlotIterator for PyBytesIterator {
|
||||
fn next(zelf: &PyRef<Self>, vm: &VirtualMachine) -> PyResult<PyIterReturn> {
|
||||
let pos = zelf.position.fetch_add(1);
|
||||
let r = if let Some(&ret) = zelf.bytes.as_bytes().get(pos) {
|
||||
PyIterReturn::Return(vm.ctx.new_int(ret))
|
||||
} else {
|
||||
PyIterReturn::StopIteration(None)
|
||||
};
|
||||
Ok(r)
|
||||
fn next(zelf: &PyRef<Self>, vm: &VirtualMachine) -> PyResult {
|
||||
zelf.internal.next(
|
||||
|pos| {
|
||||
let bytes = zelf.internal.obj.read();
|
||||
let bytes = bytes.payload::<PyBytes>().unwrap();
|
||||
bytes
|
||||
.as_bytes()
|
||||
.get(pos)
|
||||
.ok_or_else(|| vm.new_stop_iteration())
|
||||
.map(|&x| vm.ctx.new_int(x))
|
||||
},
|
||||
vm,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -12,6 +12,88 @@ use crate::{
|
||||
};
|
||||
use crossbeam_utils::atomic::AtomicCell;
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct PositionIterInternal {
|
||||
pub position: AtomicCell<usize>,
|
||||
/// object or PyNone if exhausted
|
||||
pub obj: PyRwLock<PyObjectRef>,
|
||||
}
|
||||
|
||||
impl PositionIterInternal {
|
||||
pub fn new(obj: PyObjectRef) -> Self {
|
||||
Self {
|
||||
position: AtomicCell::new(0),
|
||||
obj: PyRwLock::new(obj),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn is_active(&self, vm: &VirtualMachine) -> bool {
|
||||
!vm.is_none(&self.obj.read())
|
||||
}
|
||||
|
||||
pub fn set_state(&self, state: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> {
|
||||
if self.is_active(vm) {
|
||||
if let Some(i) = state.payload::<PyInt>() {
|
||||
let i = int::try_to_primitive(i.as_bigint(), vm).unwrap_or(0);
|
||||
self.position.store(i);
|
||||
Ok(())
|
||||
} else {
|
||||
Err(vm.new_type_error("an integer is required.".to_owned()))
|
||||
}
|
||||
} else {
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
pub fn reduce(&self, func: PyObjectRef, vm: &VirtualMachine) -> PyObjectRef {
|
||||
if self.is_active(vm) {
|
||||
vm.ctx.new_tuple(vec![
|
||||
func,
|
||||
vm.ctx.new_tuple(vec![self.obj.read().clone()]),
|
||||
vm.ctx.new_int(self.position.load()),
|
||||
])
|
||||
} else {
|
||||
vm.ctx
|
||||
.new_tuple(vec![func, vm.ctx.new_tuple(vec![vm.ctx.new_list(vec![])])])
|
||||
}
|
||||
}
|
||||
|
||||
pub fn next<F>(&self, f: F, vm: &VirtualMachine) -> PyResult
|
||||
where
|
||||
F: FnOnce(usize) -> PyResult,
|
||||
{
|
||||
if self.is_active(vm) {
|
||||
let pos = self.position.fetch_add(1);
|
||||
match f(pos) {
|
||||
Err(ref e)
|
||||
if e.isinstance(&vm.ctx.exceptions.index_error)
|
||||
|| e.isinstance(&vm.ctx.exceptions.stop_iteration) =>
|
||||
{
|
||||
*self.obj.write() = vm.ctx.none();
|
||||
Err(vm.new_stop_iteration())
|
||||
}
|
||||
ret => ret,
|
||||
}
|
||||
} else {
|
||||
Err(vm.new_stop_iteration())
|
||||
}
|
||||
}
|
||||
|
||||
pub fn length_hint<F>(&self, f: F, vm: &VirtualMachine) -> PyResult
|
||||
where
|
||||
F: FnOnce() -> PyResult<usize>,
|
||||
{
|
||||
let len = if self.is_active(vm) {
|
||||
let pos = self.position.load();
|
||||
let obj_len = f()?;
|
||||
obj_len.saturating_sub(pos)
|
||||
} else {
|
||||
0
|
||||
};
|
||||
Ok(PyInt::from(len).into_object(vm))
|
||||
}
|
||||
}
|
||||
|
||||
/// Marks status of iterator.
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
pub enum IterStatus {
|
||||
@@ -24,9 +106,10 @@ pub enum IterStatus {
|
||||
#[pyclass(module = false, name = "iterator")]
|
||||
#[derive(Debug)]
|
||||
pub struct PySequenceIterator {
|
||||
pub position: AtomicCell<usize>,
|
||||
pub obj: PyObjectRef,
|
||||
pub status: AtomicCell<IterStatus>,
|
||||
internal: PositionIterInternal,
|
||||
// pub position: AtomicCell<usize>,
|
||||
// pub obj: PyObjectRef,
|
||||
// pub status: AtomicCell<IterStatus>
|
||||
}
|
||||
|
||||
impl PyValue for PySequenceIterator {
|
||||
@@ -39,75 +122,38 @@ impl PyValue for PySequenceIterator {
|
||||
impl PySequenceIterator {
|
||||
pub fn new(obj: PyObjectRef) -> Self {
|
||||
Self {
|
||||
position: AtomicCell::new(0),
|
||||
obj,
|
||||
status: AtomicCell::new(IterStatus::Active),
|
||||
internal: PositionIterInternal::new(obj),
|
||||
}
|
||||
}
|
||||
|
||||
#[pymethod(magic)]
|
||||
fn length_hint(&self, vm: &VirtualMachine) -> PyObjectRef {
|
||||
match self.status.load() {
|
||||
IterStatus::Active => {
|
||||
let pos = self.position.load();
|
||||
// return NotImplemented if no length is around.
|
||||
vm.obj_len(&self.obj)
|
||||
.map_or(vm.ctx.not_implemented(), |len| {
|
||||
PyInt::from(len.saturating_sub(pos)).into_object(vm)
|
||||
})
|
||||
}
|
||||
IterStatus::Exhausted => PyInt::from(0).into_object(vm),
|
||||
}
|
||||
fn length_hint(&self, vm: &VirtualMachine) -> PyResult {
|
||||
self.internal.length_hint(
|
||||
|| {
|
||||
vm.obj_len(&self.internal.obj.read())
|
||||
.map_err(|_| vm.new_not_implemented_error("".to_owned()))
|
||||
},
|
||||
vm,
|
||||
)
|
||||
}
|
||||
|
||||
#[pymethod(magic)]
|
||||
fn reduce(&self, vm: &VirtualMachine) -> PyResult {
|
||||
let iter = vm.get_attribute(vm.builtins.clone(), "iter")?;
|
||||
Ok(match self.status.load() {
|
||||
IterStatus::Exhausted => vm
|
||||
.ctx
|
||||
.new_tuple(vec![iter, vm.ctx.new_tuple(vec![vm.ctx.new_list(vec![])])]),
|
||||
IterStatus::Active => vm.ctx.new_tuple(vec![
|
||||
iter,
|
||||
vm.ctx.new_tuple(vec![self.obj.clone()]),
|
||||
vm.ctx.new_int(self.position.load()),
|
||||
]),
|
||||
})
|
||||
Ok(self.internal.reduce(iter, vm))
|
||||
}
|
||||
|
||||
#[pymethod(magic)]
|
||||
fn setstate(&self, state: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> {
|
||||
// When we're exhausted, just return.
|
||||
if let IterStatus::Exhausted = self.status.load() {
|
||||
return Ok(());
|
||||
}
|
||||
if let Some(i) = state.payload::<PyInt>() {
|
||||
self.position
|
||||
.store(int::try_to_primitive(i.as_bigint(), vm).unwrap_or(0));
|
||||
Ok(())
|
||||
} else {
|
||||
Err(vm.new_type_error("an integer is required.".to_owned()))
|
||||
}
|
||||
self.internal.set_state(state, vm)
|
||||
}
|
||||
}
|
||||
|
||||
impl IteratorIterable for PySequenceIterator {}
|
||||
impl SlotIterator for PySequenceIterator {
|
||||
fn next(zelf: &PyRef<Self>, vm: &VirtualMachine) -> PyResult<PyIterReturn> {
|
||||
if let IterStatus::Exhausted = zelf.status.load() {
|
||||
return Ok(PyIterReturn::StopIteration(None));
|
||||
}
|
||||
let pos = zelf.position.fetch_add(1);
|
||||
match zelf.obj.get_item(pos, vm) {
|
||||
Err(ref e)
|
||||
if e.isinstance(&vm.ctx.exceptions.index_error)
|
||||
|| e.isinstance(&vm.ctx.exceptions.stop_iteration) =>
|
||||
{
|
||||
zelf.status.store(IterStatus::Exhausted);
|
||||
Ok(PyIterReturn::StopIteration(None))
|
||||
}
|
||||
ret => ret.map(PyIterReturn::Return),
|
||||
}
|
||||
fn next(zelf: &PyRef<Self>, vm: &VirtualMachine) -> PyResult {
|
||||
zelf.internal
|
||||
.next(|pos| zelf.internal.obj.read().get_item(pos, vm), vm)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -310,9 +310,7 @@ impl Comparable for PyTuple {
|
||||
impl Iterable for PyTuple {
|
||||
fn iter(zelf: PyRef<Self>, vm: &VirtualMachine) -> PyResult {
|
||||
Ok(PyTupleIterator {
|
||||
position: AtomicCell::new(0),
|
||||
status: AtomicCell::new(Active),
|
||||
tuple: zelf,
|
||||
internal: PositionIterInternal::new(zelf.into_object()),
|
||||
}
|
||||
.into_object(vm))
|
||||
}
|
||||
@@ -321,9 +319,7 @@ impl Iterable for PyTuple {
|
||||
#[pyclass(module = false, name = "tuple_iterator")]
|
||||
#[derive(Debug)]
|
||||
pub(crate) struct PyTupleIterator {
|
||||
position: AtomicCell<usize>,
|
||||
status: AtomicCell<IterStatus>,
|
||||
tuple: PyTupleRef,
|
||||
internal: PositionIterInternal,
|
||||
}
|
||||
|
||||
impl PyValue for PyTupleIterator {
|
||||
@@ -335,61 +331,40 @@ impl PyValue for PyTupleIterator {
|
||||
#[pyimpl(with(SlotIterator))]
|
||||
impl PyTupleIterator {
|
||||
#[pymethod(magic)]
|
||||
fn length_hint(&self) -> usize {
|
||||
match self.status.load() {
|
||||
Active => self.tuple.len().saturating_sub(self.position.load()),
|
||||
Exhausted => 0,
|
||||
}
|
||||
fn length_hint(&self, vm: &VirtualMachine) -> PyResult {
|
||||
self.internal.length_hint(
|
||||
|| Ok(self.internal.obj.read().payload::<PyTuple>().unwrap().len()),
|
||||
vm,
|
||||
)
|
||||
}
|
||||
|
||||
#[pymethod(magic)]
|
||||
fn setstate(&self, state: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> {
|
||||
// When we're exhausted, just return.
|
||||
if let Exhausted = self.status.load() {
|
||||
return Ok(());
|
||||
}
|
||||
// Else, set to min of (pos, tuple_size).
|
||||
if let Some(i) = state.payload::<PyInt>() {
|
||||
let position = std::cmp::min(
|
||||
int::try_to_primitive(i.as_bigint(), vm).unwrap_or(0),
|
||||
self.tuple.len(),
|
||||
);
|
||||
self.position.store(position);
|
||||
Ok(())
|
||||
} else {
|
||||
Err(vm.new_type_error("an integer is required.".to_owned()))
|
||||
}
|
||||
self.internal.set_state(state, vm)
|
||||
}
|
||||
|
||||
#[pymethod(magic)]
|
||||
fn reduce(&self, vm: &VirtualMachine) -> PyResult {
|
||||
let iter = vm.get_attribute(vm.builtins.clone(), "iter")?;
|
||||
Ok(match self.status.load() {
|
||||
Exhausted => vm
|
||||
.ctx
|
||||
.new_tuple(vec![iter, vm.ctx.new_tuple(vec![vm.ctx.new_list(vec![])])]),
|
||||
Active => vm.ctx.new_tuple(vec![
|
||||
iter,
|
||||
vm.ctx.new_tuple(vec![self.tuple.clone().into_object()]),
|
||||
vm.ctx.new_int(self.position.load()),
|
||||
]),
|
||||
})
|
||||
Ok(self.internal.reduce(iter, vm))
|
||||
}
|
||||
}
|
||||
|
||||
impl IteratorIterable for PyTupleIterator {}
|
||||
impl SlotIterator for PyTupleIterator {
|
||||
fn next(zelf: &PyRef<Self>, _vm: &VirtualMachine) -> PyResult<PyIterReturn> {
|
||||
if let Exhausted = zelf.status.load() {
|
||||
return Ok(PyIterReturn::StopIteration(None));
|
||||
}
|
||||
let pos = zelf.position.fetch_add(1);
|
||||
if let Some(obj) = zelf.tuple.as_slice().get(pos) {
|
||||
Ok(PyIterReturn::Return(obj.clone()))
|
||||
} else {
|
||||
zelf.status.store(Exhausted);
|
||||
Ok(PyIterReturn::StopIteration(None))
|
||||
}
|
||||
fn next(zelf: &PyRef<Self>, vm: &VirtualMachine) -> PyResult {
|
||||
zelf.internal.next(
|
||||
|pos| {
|
||||
let tuple = zelf.internal.obj.read();
|
||||
let tuple = tuple.payload::<PyTuple>().unwrap();
|
||||
tuple
|
||||
.as_slice()
|
||||
.get(pos)
|
||||
.ok_or_else(|| vm.new_stop_iteration())
|
||||
.map(|x| x.clone())
|
||||
},
|
||||
vm,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user