diff --git a/vm/src/obj/objenumerate.rs b/vm/src/obj/objenumerate.rs index 4bbe8caccf..3c6c864c2e 100644 --- a/vm/src/obj/objenumerate.rs +++ b/vm/src/obj/objenumerate.rs @@ -1,5 +1,5 @@ -use std::cell::RefCell; use std::ops::AddAssign; +use std::sync::RwLock; use num_bigint::BigInt; use num_traits::Zero; @@ -8,16 +8,17 @@ use super::objint::PyIntRef; use super::objiter; use super::objtype::PyClassRef; use crate::function::OptionalArg; -use crate::pyobject::{PyClassImpl, PyContext, PyObjectRef, PyRef, PyResult, PyValue}; +use crate::pyobject::{PyClassImpl, PyContext, PyObjectRef, PyRef, PyResult, PyValue, ThreadSafe}; use crate::vm::VirtualMachine; #[pyclass] #[derive(Debug)] pub struct PyEnumerate { - counter: RefCell, + counter: RwLock, iterator: PyObjectRef, } type PyEnumerateRef = PyRef; +impl ThreadSafe for PyEnumerate {} impl PyValue for PyEnumerate { fn class(vm: &VirtualMachine) -> PyClassRef { @@ -41,24 +42,19 @@ impl PyEnumerate { let iterator = objiter::get_iter(vm, &iterable)?; PyEnumerate { - counter: RefCell::new(counter), + counter: RwLock::new(counter), iterator, } .into_ref_with_type(vm, cls) } #[pymethod(name = "__next__")] - fn next(&self, vm: &VirtualMachine) -> PyResult { - let iterator = &self.iterator; - let counter = &self.counter; - let next_obj = objiter::call_next(vm, iterator)?; - let result = vm - .ctx - .new_tuple(vec![vm.ctx.new_bigint(&counter.borrow()), next_obj]); - - AddAssign::add_assign(&mut counter.borrow_mut() as &mut BigInt, 1); - - Ok(result) + fn next(&self, vm: &VirtualMachine) -> PyResult<(BigInt, PyObjectRef)> { + let next_obj = objiter::call_next(vm, &self.iterator)?; + let mut counter = self.counter.write().unwrap(); + let position = counter.clone(); + AddAssign::add_assign(&mut counter as &mut BigInt, 1); + Ok((position, next_obj)) } #[pymethod(name = "__iter__")]