diff --git a/vm/src/builtins/int.rs b/vm/src/builtins/int.rs index 920fc703da..6407409acf 100644 --- a/vm/src/builtins/int.rs +++ b/vm/src/builtins/int.rs @@ -425,62 +425,63 @@ impl PyInt { self.int_op(other, |a, b| a & b, vm) } + fn modpow(&self, other: PyObjectRef, modulus: PyObjectRef, vm: &VirtualMachine) -> PyResult { + let modulus = match modulus.payload_if_subclass::(vm) { + Some(val) => val.as_bigint(), + None => return Ok(vm.ctx.not_implemented()), + }; + if modulus.is_zero() { + return Err(vm.new_value_error("pow() 3rd argument cannot be 0".to_owned())); + } + + self.general_op( + other, + |a, b| { + let i = if b.is_negative() { + // modular multiplicative inverse + // based on rust-num/num-integer#10, should hopefully be published soon + fn normalize(a: BigInt, n: &BigInt) -> BigInt { + let a = a % n; + if a.is_negative() { + a + n + } else { + a + } + } + fn inverse(a: BigInt, n: &BigInt) -> Option { + use num_integer::*; + let ExtendedGcd { gcd, x: c, .. } = a.extended_gcd(n); + if gcd.is_one() { + Some(normalize(c, n)) + } else { + None + } + } + let a = inverse(a % modulus, modulus).ok_or_else(|| { + vm.new_value_error( + "base is not invertible for the given modulus".to_owned(), + ) + })?; + let b = -b; + a.modpow(&b, modulus) + } else { + a.modpow(b, modulus) + }; + Ok(vm.ctx.new_int(i).into()) + }, + vm, + ) + } + #[pymethod(magic)] fn pow( &self, other: PyObjectRef, - mod_val: OptionalOption, + r#mod: OptionalOption, vm: &VirtualMachine, ) -> PyResult { - match mod_val.flatten() { - Some(int_ref) => { - let int = match int_ref.payload_if_subclass::(vm) { - Some(val) => val, - None => return Ok(vm.ctx.not_implemented()), - }; - - let modulus = int.as_bigint(); - if modulus.is_zero() { - return Err(vm.new_value_error("pow() 3rd argument cannot be 0".to_owned())); - } - self.general_op( - other, - |a, b| { - let i = if b.is_negative() { - // modular multiplicative inverse - // based on rust-num/num-integer#10, should hopefully be published soon - fn normalize(a: BigInt, n: &BigInt) -> BigInt { - let a = a % n; - if a.is_negative() { - a + n - } else { - a - } - } - fn inverse(a: BigInt, n: &BigInt) -> Option { - use num_integer::*; - let ExtendedGcd { gcd, x: c, .. } = a.extended_gcd(n); - if gcd.is_one() { - Some(normalize(c, n)) - } else { - None - } - } - let a = inverse(a % modulus, modulus).ok_or_else(|| { - vm.new_value_error( - "base is not invertible for the given modulus".to_owned(), - ) - })?; - let b = -b; - a.modpow(&b, modulus) - } else { - a.modpow(b, modulus) - }; - Ok(vm.ctx.new_int(i).into()) - }, - vm, - ) - } + match r#mod.flatten() { + Some(modulus) => self.modpow(other, modulus, vm), None => self.general_op(other, |a, b| inner_pow(a, b, vm), vm), } }