diff --git a/crates/stdlib/src/_sqlite3.rs b/crates/stdlib/src/_sqlite3.rs index 6855bf5d0..0a889c4d1 100644 --- a/crates/stdlib/src/_sqlite3.rs +++ b/crates/stdlib/src/_sqlite3.rs @@ -61,8 +61,8 @@ mod _sqlite3 { }, convert::IntoObject, function::{ - ArgCallable, ArgIterable, Either, FsPath, FuncArgs, OptionalArg, PyComparisonValue, - PySetterValue, + ArgCallable, ArgIterable, FsPath, FuncArgs, OptionalArg, PyComparisonValue, + PySetterValue, TimeoutSeconds, }, object::{Traverse, TraverseFn}, protocol::{ @@ -333,8 +333,8 @@ mod _sqlite3 { struct ConnectArgs { #[pyarg(any)] database: FsPath, - #[pyarg(any, default = Either::A(5.0))] - timeout: Either, + #[pyarg(any, default = TimeoutSeconds::new(5.0))] + timeout: TimeoutSeconds, #[pyarg(any, default = 0)] detect_types: c_int, #[pyarg(any, default = Some(vm.ctx.empty_str.to_owned()))] @@ -991,10 +991,7 @@ mod _sqlite3 { fn initialize_db(args: &ConnectArgs, vm: &VirtualMachine) -> PyResult { let path = args.database.to_cstring(vm)?; let db = Sqlite::from(SqliteRaw::open(path.as_ptr(), args.uri, vm)?); - let timeout = (match args.timeout { - Either::A(float) => float, - Either::B(int) => int as f64, - } * 1000.0) as c_int; + let timeout = (args.timeout.to_secs_f64() * 1000.0) as c_int; db.busy_timeout(timeout); if let Some(isolation_level) = &args.isolation_level { begin_statement_ptr_from_isolation_level(isolation_level, vm)?; diff --git a/crates/vm/src/function/mod.rs b/crates/vm/src/function/mod.rs index 150489195..4be94e3f0 100644 --- a/crates/vm/src/function/mod.rs +++ b/crates/vm/src/function/mod.rs @@ -8,6 +8,7 @@ mod getset; mod method; mod number; mod protocol; +mod time; pub use argument::{ ArgumentError, FromArgOptional, FromArgs, FuncArgs, IntoFuncArgs, KwArgs, OptionalArg, @@ -23,6 +24,7 @@ pub(super) use getset::{IntoPyGetterFunc, IntoPySetterFunc, PyGetterFunc, PySett pub use method::{HeapMethodDef, PyMethodDef, PyMethodFlags}; pub use number::{ArgIndex, ArgIntoBool, ArgIntoComplex, ArgIntoFloat, ArgPrimitiveIndex, ArgSize}; pub use protocol::{ArgCallable, ArgIterable, ArgMapping, ArgSequence}; +pub use time::TimeoutSeconds; use crate::{PyObject, PyResult, VirtualMachine, builtins::PyStr, convert::TryFromBorrowedObject}; use builtin::{BorrowedParam, OwnedParam, RefParam}; diff --git a/crates/vm/src/function/time.rs b/crates/vm/src/function/time.rs new file mode 100644 index 000000000..29f14495d --- /dev/null +++ b/crates/vm/src/function/time.rs @@ -0,0 +1,34 @@ +use crate::{PyObjectRef, PyResult, TryFromObject, VirtualMachine}; + +/// A Python timeout value that accepts both `float` and `int`. +/// +/// `TimeoutSeconds` implements `FromArgs` so that a built-in function can accept +/// timeout parameters given as either `float` or `int`, normalizing them to `f64`. +#[derive(Debug, Clone, Copy, PartialEq)] +pub struct TimeoutSeconds { + value: f64, +} + +impl TimeoutSeconds { + pub const fn new(secs: f64) -> Self { + Self { value: secs } + } + + #[inline] + pub fn to_secs_f64(self) -> f64 { + self.value + } +} + +impl TryFromObject for TimeoutSeconds { + fn try_from_object(vm: &VirtualMachine, obj: PyObjectRef) -> PyResult { + let value = match super::Either::::try_from_object(vm, obj)? { + super::Either::A(f) => f, + super::Either::B(i) => i as f64, + }; + if value.is_nan() { + return Err(vm.new_value_error("Invalid value NaN (not a number)".to_owned())); + } + Ok(Self { value }) + } +} diff --git a/crates/vm/src/stdlib/thread.rs b/crates/vm/src/stdlib/thread.rs index f6849a116..765f25374 100644 --- a/crates/vm/src/stdlib/thread.rs +++ b/crates/vm/src/stdlib/thread.rs @@ -15,7 +15,7 @@ pub(crate) mod _thread { builtins::{PyDictRef, PyStr, PyTupleRef, PyType, PyTypeRef, PyUtf8StrRef}, common::wtf8::Wtf8Buf, frame::FrameRef, - function::{ArgCallable, Either, FuncArgs, KwArgs, OptionalArg, PySetterValue}, + function::{ArgCallable, FuncArgs, KwArgs, OptionalArg, PySetterValue, TimeoutSeconds}, types::{Constructor, GetAttr, Representable, SetAttr}, }; use alloc::{ @@ -65,33 +65,26 @@ pub(crate) mod _thread { struct AcquireArgs { #[pyarg(any, default = true)] blocking: bool, - #[pyarg(any, default = Either::A(-1.0))] - timeout: Either, + #[pyarg(any, default = TimeoutSeconds::new(-1.0))] + timeout: TimeoutSeconds, } macro_rules! acquire_lock_impl { ($mu:expr, $args:expr, $vm:expr) => {{ let (mu, args, vm) = ($mu, $args, $vm); - let timeout = match args.timeout { - Either::A(f) => f, - Either::B(i) => i as f64, - }; + let timeout = args.timeout.to_secs_f64(); match args.blocking { true if timeout == -1.0 => { vm.allow_threads(|| mu.lock()); Ok(true) } true if timeout < 0.0 => { - Err(vm.new_value_error("timeout value must be positive".to_owned())) + Err(vm + .new_value_error("timeout value must be a non-negative number".to_owned())) } true => { - // modified from std::time::Duration::from_secs_f64 to avoid a panic. - // TODO: put this in the Duration::try_from_object impl, maybe? - let nanos = timeout * 1_000_000_000.0; - if timeout > TIMEOUT_MAX as f64 || nanos < 0.0 || !nanos.is_finite() { - return Err(vm.new_overflow_error( - "timestamp too large to convert to Rust Duration".to_owned(), - )); + if timeout > TIMEOUT_MAX { + return Err(vm.new_overflow_error("timeout value is too large".to_owned())); } Ok(vm.allow_threads(|| mu.try_lock_for(Duration::from_secs_f64(timeout))))