diff --git a/vm/src/obj/objfloat.rs b/vm/src/obj/objfloat.rs index 2adc29113..8d0dcc1ed 100644 --- a/vm/src/obj/objfloat.rs +++ b/vm/src/obj/objfloat.rs @@ -4,14 +4,14 @@ use num_rational::Ratio; use num_traits::{float::Float, pow, sign::Signed, ToPrimitive, Zero}; use super::objbytes; -use super::objint; +use super::objint::{self, PyIntRef}; use super::objstr::{self, PyStringRef}; use super::objtype::{self, PyClassRef}; -use crate::function::OptionalArg; +use crate::function::{OptionalArg, OptionalOption}; use crate::pyhash; use crate::pyobject::{ - IdProtocol, IntoPyObject, PyClassImpl, PyContext, PyObjectRef, PyRef, PyResult, PyValue, - TryFromObject, TypeProtocol, + IntoPyObject, PyClassImpl, PyContext, PyObjectRef, PyRef, PyResult, PyValue, TryFromObject, + TypeProtocol, }; use crate::vm::VirtualMachine; @@ -458,26 +458,10 @@ impl PyFloat { } #[pymethod(name = "__round__")] - fn round(&self, ndigits: OptionalArg, vm: &VirtualMachine) -> PyResult { - let ndigits = match ndigits { - OptionalArg::Missing => None, - OptionalArg::Present(ref value) => { - if !vm.get_none().is(value) { - if !objtype::isinstance(value, &vm.ctx.int_type()) { - return Err(vm.new_type_error(format!( - "'{}' object cannot be interpreted as an integer", - value.class().name - ))); - }; - // Only accept int type ndigits - let ndigits = objint::get_value(value); - Some(ndigits) - } else { - None - } - } - }; + fn round(&self, ndigits: OptionalOption, vm: &VirtualMachine) -> PyResult { + let ndigits = ndigits.flat_option(); if let Some(ndigits) = ndigits { + let ndigits = ndigits.as_bigint(); if ndigits.is_zero() { let fract = self.value.fract(); let value = if (fract.abs() - 0.5).abs() < std::f64::EPSILON {