cut down modpow to keep fast path smaller

This commit is contained in:
Jeong Yunwon
2022-05-04 08:00:52 +09:00
parent c303127912
commit 283c97cda7

View File

@@ -425,62 +425,63 @@ impl PyInt {
self.int_op(other, |a, b| a & b, vm)
}
fn modpow(&self, other: PyObjectRef, modulus: PyObjectRef, vm: &VirtualMachine) -> PyResult {
let modulus = match modulus.payload_if_subclass::<PyInt>(vm) {
Some(val) => val.as_bigint(),
None => return Ok(vm.ctx.not_implemented()),
};
if modulus.is_zero() {
return Err(vm.new_value_error("pow() 3rd argument cannot be 0".to_owned()));
}
self.general_op(
other,
|a, b| {
let i = if b.is_negative() {
// modular multiplicative inverse
// based on rust-num/num-integer#10, should hopefully be published soon
fn normalize(a: BigInt, n: &BigInt) -> BigInt {
let a = a % n;
if a.is_negative() {
a + n
} else {
a
}
}
fn inverse(a: BigInt, n: &BigInt) -> Option<BigInt> {
use num_integer::*;
let ExtendedGcd { gcd, x: c, .. } = a.extended_gcd(n);
if gcd.is_one() {
Some(normalize(c, n))
} else {
None
}
}
let a = inverse(a % modulus, modulus).ok_or_else(|| {
vm.new_value_error(
"base is not invertible for the given modulus".to_owned(),
)
})?;
let b = -b;
a.modpow(&b, modulus)
} else {
a.modpow(b, modulus)
};
Ok(vm.ctx.new_int(i).into())
},
vm,
)
}
#[pymethod(magic)]
fn pow(
&self,
other: PyObjectRef,
mod_val: OptionalOption<PyObjectRef>,
r#mod: OptionalOption<PyObjectRef>,
vm: &VirtualMachine,
) -> PyResult {
match mod_val.flatten() {
Some(int_ref) => {
let int = match int_ref.payload_if_subclass::<PyInt>(vm) {
Some(val) => val,
None => return Ok(vm.ctx.not_implemented()),
};
let modulus = int.as_bigint();
if modulus.is_zero() {
return Err(vm.new_value_error("pow() 3rd argument cannot be 0".to_owned()));
}
self.general_op(
other,
|a, b| {
let i = if b.is_negative() {
// modular multiplicative inverse
// based on rust-num/num-integer#10, should hopefully be published soon
fn normalize(a: BigInt, n: &BigInt) -> BigInt {
let a = a % n;
if a.is_negative() {
a + n
} else {
a
}
}
fn inverse(a: BigInt, n: &BigInt) -> Option<BigInt> {
use num_integer::*;
let ExtendedGcd { gcd, x: c, .. } = a.extended_gcd(n);
if gcd.is_one() {
Some(normalize(c, n))
} else {
None
}
}
let a = inverse(a % modulus, modulus).ok_or_else(|| {
vm.new_value_error(
"base is not invertible for the given modulus".to_owned(),
)
})?;
let b = -b;
a.modpow(&b, modulus)
} else {
a.modpow(b, modulus)
};
Ok(vm.ctx.new_int(i).into())
},
vm,
)
}
match r#mod.flatten() {
Some(modulus) => self.modpow(other, modulus, vm),
None => self.general_op(other, |a, b| inner_pow(a, b, vm), vm),
}
}