diff --git a/vm/src/stdlib/thread.rs b/vm/src/stdlib/thread.rs index 425a3e3f2..1a3ec9a94 100644 --- a/vm/src/stdlib/thread.rs +++ b/vm/src/stdlib/thread.rs @@ -4,7 +4,10 @@ use crate::function::{Args, KwArgs, OptionalArg, PyFuncArgs}; use crate::obj::objdict::PyDictRef; use crate::obj::objtuple::PyTupleRef; use crate::obj::objtype::PyClassRef; -use crate::pyobject::{Either, PyCallable, PyClassImpl, PyObjectRef, PyRef, PyResult, PyValue}; +use crate::pyobject::{ + Either, IdProtocol, PyCallable, PyClassImpl, PyObjectRef, PyRef, PyResult, PyValue, + TypeProtocol, +}; use crate::vm::VirtualMachine; use parking_lot::{ @@ -16,18 +19,20 @@ use std::io::Write; use std::time::Duration; use std::{fmt, thread}; +// PY_TIMEOUT_MAX is a value in microseconds #[cfg(not(target_os = "windows"))] -const PY_TIMEOUT_MAX: isize = std::isize::MAX; +const PY_TIMEOUT_MAX: isize = std::isize::MAX / 1_000; #[cfg(target_os = "windows")] -const PY_TIMEOUT_MAX: isize = 0xffffffff * 1_000_000; +const PY_TIMEOUT_MAX: isize = 0xffffffff * 1_000; -const TIMEOUT_MAX: f64 = (PY_TIMEOUT_MAX / 1_000_000_000) as f64; +// this is a value in seconds +const TIMEOUT_MAX: f64 = (PY_TIMEOUT_MAX / 1_000_000) as f64; #[derive(FromArgs)] struct AcquireArgs { #[pyarg(positional_or_keyword, default = "true")] - waitflag: bool, + blocking: bool, #[pyarg(positional_or_keyword, default = "Either::A(-1.0)")] timeout: Either, } @@ -39,7 +44,7 @@ macro_rules! acquire_lock_impl { Either::A(f) => f, Either::B(i) => i as f64, }; - match args.waitflag { + match args.blocking { true if timeout == -1.0 => { mu.lock(); Ok(true) @@ -48,7 +53,16 @@ macro_rules! acquire_lock_impl { Err(vm.new_value_error("timeout value must be positive".to_owned())) } true => { - // TODO: respect TIMEOUT_MAX here + // modified from std::time::Duration::from_secs_f64 to avoid a panic. + // TODO: put this in the Duration::try_from_object impl, maybe? + let micros = timeout * 1_000_000.0; + let nanos = timeout * 1_000_000_000.0; + if micros > PY_TIMEOUT_MAX as f64 || nanos < 0.0 || !nanos.is_finite() { + return Err(vm.new_overflow_error( + "timestamp too large to convert to Rust Duration".to_owned(), + )); + } + Ok(mu.try_lock_for(Duration::from_secs_f64(timeout))) } false if timeout != -1.0 => { @@ -59,6 +73,21 @@ macro_rules! acquire_lock_impl { } }}; } +macro_rules! repr_lock_impl { + ($zelf:expr) => {{ + let status = if $zelf.mu.is_locked() { + "locked" + } else { + "unlocked" + }; + format!( + "<{} {} object at {}>", + status, + $zelf.class().name, + $zelf.get_id() + ) + }}; +} #[pyclass(name = "lock")] struct PyLock { @@ -102,6 +131,11 @@ impl PyLock { fn locked(&self) -> bool { self.mu.is_locked() } + + #[pymethod(magic)] + fn repr(zelf: PyRef) -> String { + repr_lock_impl!(zelf) + } } type RawRMutex = RawReentrantMutex; @@ -149,6 +183,11 @@ impl PyRLock { fn exit(&self, _args: PyFuncArgs) { self.release() } + + #[pymethod(magic)] + fn repr(zelf: PyRef) -> String { + repr_lock_impl!(zelf) + } } fn thread_get_ident() -> u64 { @@ -195,12 +234,16 @@ fn thread_start_new_thread( } SENTINELS.with(|sents| { for lock in sents.replace(Default::default()) { - lock.release() + lock.mu.unlock() } - }) + }); + vm.state.thread_count.fetch_sub(1); }); - res.map(|handle| thread_to_id(&handle.thread())) - .map_err(|err| super::os::convert_io_error(vm, err)) + res.map(|handle| { + vm.state.thread_count.fetch_add(1); + thread_to_id(&handle.thread()) + }) + .map_err(|err| super::os::convert_io_error(vm, err)) } thread_local!(static SENTINELS: RefCell> = RefCell::default()); @@ -217,6 +260,10 @@ fn thread_stack_size(size: OptionalArg, vm: &VirtualMachine) -> usize { vm.state.stacksize.swap(size) } +fn thread_count(vm: &VirtualMachine) -> usize { + vm.state.thread_count.load() +} + pub fn make_module(vm: &VirtualMachine) -> PyObjectRef { let ctx = &vm.ctx; @@ -228,6 +275,7 @@ pub fn make_module(vm: &VirtualMachine) -> PyObjectRef { "start_new_thread" => ctx.new_function(thread_start_new_thread), "_set_sentinel" => ctx.new_function(thread_set_sentinel), "stack_size" => ctx.new_function(thread_stack_size), + "_count" => ctx.new_function(thread_count), "error" => ctx.exceptions.runtime_error.clone(), "TIMEOUT_MAX" => ctx.new_float(TIMEOUT_MAX), }) diff --git a/vm/src/vm.rs b/vm/src/vm.rs index ebcf66afa..9bf5e2f70 100644 --- a/vm/src/vm.rs +++ b/vm/src/vm.rs @@ -76,6 +76,7 @@ pub struct PyGlobalState { pub stdlib_inits: HashMap, pub frozen: HashMap, pub stacksize: AtomicCell, + pub thread_count: AtomicCell, } pub const NSIG: usize = 64; @@ -207,6 +208,7 @@ impl VirtualMachine { stdlib_inits, frozen, stacksize: AtomicCell::new(0), + thread_count: AtomicCell::new(0), }), initialized: false, };