diff --git a/tests/snippets/ints.py b/tests/snippets/ints.py index 3db203a3e..baa145bff 100644 --- a/tests/snippets/ints.py +++ b/tests/snippets/ints.py @@ -49,6 +49,8 @@ with assert_raises(ZeroDivisionError): assert (-3).__rdivmod__(2) == (-1, -1) assert (2).__pow__(3) == 8 assert (10).__pow__(-1) == 0.1 +with assert_raises(ZeroDivisionError): + (0).__pow__(-1) assert (2).__rpow__(3) == 9 assert (10).__mod__(5) == 0 assert (10).__mod__(6) == 4 diff --git a/vm/src/obj/objfloat.rs b/vm/src/obj/objfloat.rs index 16b784847..fbcc6fcc0 100644 --- a/vm/src/obj/objfloat.rs +++ b/vm/src/obj/objfloat.rs @@ -154,6 +154,15 @@ fn inner_gt_int(value: f64, other_int: &BigInt) -> bool { } } +pub fn float_pow(v1: f64, v2: f64, vm: &VirtualMachine) -> PyResult { + if v1.is_zero() { + let msg = format!("{} cannot be raised to a negative power", v1); + Err(vm.new_zero_division_error(msg)) + } else { + v1.powf(v2).into_pyobject(vm) + } +} + #[pyimpl] #[allow(clippy::trivially_copy_pass_by_ref)] impl PyFloat { @@ -359,7 +368,7 @@ impl PyFloat { fn pow(&self, other: PyObjectRef, vm: &VirtualMachine) -> PyResult { try_float(&other, vm)?.map_or_else( || Ok(vm.ctx.not_implemented()), - |other| self.value.powf(other).into_pyobject(vm), + |other| float_pow(self.value, other, vm), ) } @@ -367,7 +376,7 @@ impl PyFloat { fn rpow(&self, other: PyObjectRef, vm: &VirtualMachine) -> PyResult { try_float(&other, vm)?.map_or_else( || Ok(vm.ctx.not_implemented()), - |other| other.powf(self.value).into_pyobject(vm), + |other| float_pow(other, self.value, vm), ) } diff --git a/vm/src/obj/objint.rs b/vm/src/obj/objint.rs index 43506b58d..61be01457 100644 --- a/vm/src/obj/objint.rs +++ b/vm/src/obj/objint.rs @@ -18,6 +18,7 @@ use crate::vm::VirtualMachine; use super::objbool::IntoPyBool; use super::objbyteinner::PyByteInner; use super::objbytes::PyBytes; +use super::objfloat; use super::objint; use super::objstr::{PyString, PyStringRef}; use super::objtype; @@ -118,12 +119,12 @@ impl_try_from_object_int!( #[allow(clippy::collapsible_if)] fn inner_pow(int1: &PyInt, int2: &PyInt, vm: &VirtualMachine) -> PyResult { - let result = if int2.value.is_negative() { + if int2.value.is_negative() { let v1 = int1.float(vm)?; let v2 = int2.float(vm)?; - vm.ctx.new_float(v1.pow(v2)) + objfloat::float_pow(v1, v2, vm) } else { - if let Some(v2) = int2.value.to_u64() { + Ok(if let Some(v2) = int2.value.to_u64() { vm.ctx.new_int(int1.value.pow(v2)) } else if int1.value.is_one() || int1.value.is_zero() { vm.ctx.new_int(int1.value.clone()) @@ -137,9 +138,8 @@ fn inner_pow(int1: &PyInt, int2: &PyInt, vm: &VirtualMachine) -> PyResult { // missing feature: BigInt exp // practically, exp over u64 is not possible to calculate anyway vm.ctx.not_implemented() - } - }; - Ok(result) + }) + } } fn inner_mod(int1: &PyInt, int2: &PyInt, vm: &VirtualMachine) -> PyResult {