diff --git a/Cargo.lock b/Cargo.lock index 979a47576..43a5d9d7b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1984,6 +1984,7 @@ version = "0.2.0" dependencies = [ "ascii", "bitflags", + "bstr", "cfg-if", "hexf-parse", "itertools", diff --git a/common/Cargo.toml b/common/Cargo.toml index d98ec8e78..7707a63fa 100644 --- a/common/Cargo.toml +++ b/common/Cargo.toml @@ -13,6 +13,7 @@ threading = ["parking_lot"] [dependencies] ascii = { workspace = true } bitflags = { workspace = true } +bstr = { workspace = true } cfg-if = { workspace = true } itertools = { workspace = true } libc = { workspace = true } diff --git a/common/src/int.rs b/common/src/int.rs new file mode 100644 index 000000000..193d71435 --- /dev/null +++ b/common/src/int.rs @@ -0,0 +1,131 @@ +use bstr::ByteSlice; +use num_bigint::{BigInt, BigUint, Sign}; +use num_traits::{ToPrimitive, Zero}; + +pub fn bytes_to_int(lit: &[u8], mut base: u32) -> Option { + // split sign + let mut lit = lit.trim(); + let sign = match lit.first()? { + b'+' => Some(Sign::Plus), + b'-' => Some(Sign::Minus), + _ => None, + }; + if sign.is_some() { + lit = &lit[1..]; + } + + // split radix + let first = *lit.first()?; + let has_radix = if first == b'0' { + match base { + 0 => { + if let Some(parsed) = lit.get(1).and_then(detect_base) { + base = parsed; + true + } else { + if let [_first, ref others @ .., last] = lit { + let is_zero = + others.iter().all(|&c| c == b'0' || c == b'_') && *last == b'0'; + if !is_zero { + return None; + } + } + return Some(BigInt::zero()); + } + } + 16 => lit.get(1).map_or(false, |&b| matches!(b, b'x' | b'X')), + 2 => lit.get(1).map_or(false, |&b| matches!(b, b'b' | b'B')), + 8 => lit.get(1).map_or(false, |&b| matches!(b, b'o' | b'O')), + _ => false, + } + } else { + if base == 0 { + base = 10; + } + false + }; + if has_radix { + lit = &lit[2..]; + if lit.first()? == &b'_' { + lit = &lit[1..]; + } + } + + // remove zeroes + let mut last = *lit.first()?; + if last == b'0' { + let mut count = 0; + for &cur in &lit[1..] { + if cur == b'_' { + if last == b'_' { + return None; + } + } else if cur != b'0' { + break; + }; + count += 1; + last = cur; + } + let prefix_last = lit[count]; + lit = &lit[count + 1..]; + if lit.is_empty() && prefix_last == b'_' { + return None; + } + } + + // validate + for c in lit { + let c = *c; + if !(c.is_ascii_alphanumeric() || c == b'_') { + return None; + } + + if c == b'_' && last == b'_' { + return None; + } + + last = c; + } + if last == b'_' { + return None; + } + + // parse + let number = if lit.is_empty() { + BigInt::zero() + } else { + let uint = BigUint::parse_bytes(lit, base)?; + BigInt::from_biguint(sign.unwrap_or(Sign::Plus), uint) + }; + Some(number) +} + +#[inline] +pub fn detect_base(c: &u8) -> Option { + let base = match c { + b'x' | b'X' => 16, + b'b' | b'B' => 2, + b'o' | b'O' => 8, + _ => return None, + }; + Some(base) +} + +// num-bigint now returns Some(inf) for to_f64() in some cases, so just keep that the same for now +#[inline(always)] +pub fn bigint_to_finite_float(int: &BigInt) -> Option { + int.to_f64().filter(|f| f.is_finite()) +} + +#[test] +fn test_bytes_to_int() { + assert_eq!(bytes_to_int(&b"0b101"[..], 2).unwrap(), BigInt::from(5)); + assert_eq!(bytes_to_int(&b"0x_10"[..], 16).unwrap(), BigInt::from(16)); + assert_eq!(bytes_to_int(&b"0b"[..], 16).unwrap(), BigInt::from(11)); + assert_eq!(bytes_to_int(&b"+0b101"[..], 2).unwrap(), BigInt::from(5)); + assert_eq!(bytes_to_int(&b"0_0_0"[..], 10).unwrap(), BigInt::from(0)); + assert_eq!(bytes_to_int(&b"09_99"[..], 0), None); + assert_eq!(bytes_to_int(&b"000"[..], 0).unwrap(), BigInt::from(0)); + assert_eq!(bytes_to_int(&b"0_"[..], 0), None); + assert_eq!(bytes_to_int(&b"0_100"[..], 10).unwrap(), BigInt::from(100)); +} diff --git a/common/src/lib.rs b/common/src/lib.rs index 6b6212ef0..250f8de21 100644 --- a/common/src/lib.rs +++ b/common/src/lib.rs @@ -17,6 +17,7 @@ pub mod encodings; pub mod float_ops; pub mod format; pub mod hash; +pub mod int; pub mod linked_list; pub mod lock; pub mod os; diff --git a/vm/src/builtins/int.rs b/vm/src/builtins/int.rs index e9c602b80..9193bb793 100644 --- a/vm/src/builtins/int.rs +++ b/vm/src/builtins/int.rs @@ -3,7 +3,11 @@ use crate::{ builtins::PyStrRef, bytesinner::PyBytesInner, class::PyClassImpl, - common::{format::FormatSpec, hash}, + common::{ + format::FormatSpec, + hash, + int::{bigint_to_finite_float, bytes_to_int}, + }, convert::{IntoPyException, ToPyObject, ToPyResult}, function::{ ArgByteOrder, ArgIntoBool, OptionalArg, OptionalOption, PyArithmeticValue, @@ -14,8 +18,7 @@ use crate::{ AsObject, Context, Py, PyObject, PyObjectRef, PyPayload, PyRef, PyResult, TryFromBorrowedObject, VirtualMachine, }; -use bstr::ByteSlice; -use num_bigint::{BigInt, BigUint, Sign}; +use num_bigint::{BigInt, Sign}; use num_integer::Integer; use num_rational::Ratio; use num_traits::{One, Pow, PrimInt, Signed, ToPrimitive, Zero}; @@ -836,138 +839,16 @@ fn try_int_radix(obj: &PyObject, base: u32, vm: &VirtualMachine) -> PyResult Option { - // split sign - let mut lit = lit.trim(); - let sign = match lit.first()? { - b'+' => Some(Sign::Plus), - b'-' => Some(Sign::Minus), - _ => None, - }; - if sign.is_some() { - lit = &lit[1..]; - } - - // split radix - let first = *lit.first()?; - let has_radix = if first == b'0' { - match base { - 0 => { - if let Some(parsed) = lit.get(1).and_then(detect_base) { - base = parsed; - true - } else { - if let [_first, ref others @ .., last] = lit { - let is_zero = - others.iter().all(|&c| c == b'0' || c == b'_') && *last == b'0'; - if !is_zero { - return None; - } - } - return Some(BigInt::zero()); - } - } - 16 => lit.get(1).map_or(false, |&b| matches!(b, b'x' | b'X')), - 2 => lit.get(1).map_or(false, |&b| matches!(b, b'b' | b'B')), - 8 => lit.get(1).map_or(false, |&b| matches!(b, b'o' | b'O')), - _ => false, - } - } else { - if base == 0 { - base = 10; - } - false - }; - if has_radix { - lit = &lit[2..]; - if lit.first()? == &b'_' { - lit = &lit[1..]; - } - } - - // remove zeroes - let mut last = *lit.first()?; - if last == b'0' { - let mut count = 0; - for &cur in &lit[1..] { - if cur == b'_' { - if last == b'_' { - return None; - } - } else if cur != b'0' { - break; - }; - count += 1; - last = cur; - } - let prefix_last = lit[count]; - lit = &lit[count + 1..]; - if lit.is_empty() && prefix_last == b'_' { - return None; - } - } - - // validate - for c in lit { - let c = *c; - if !(c.is_ascii_alphanumeric() || c == b'_') { - return None; - } - - if c == b'_' && last == b'_' { - return None; - } - - last = c; - } - if last == b'_' { - return None; - } - - // parse - Some(if lit.is_empty() { - BigInt::zero() - } else { - let uint = BigUint::parse_bytes(lit, base)?; - BigInt::from_biguint(sign.unwrap_or(Sign::Plus), uint) - }) -} - -fn detect_base(c: &u8) -> Option { - match c { - b'x' | b'X' => Some(16), - b'b' | b'B' => Some(2), - b'o' | b'O' => Some(8), - _ => None, - } -} - // Retrieve inner int value: pub(crate) fn get_value(obj: &PyObject) -> &BigInt { &obj.payload::().unwrap().value } pub fn try_to_float(int: &BigInt, vm: &VirtualMachine) -> PyResult { - i2f(int).ok_or_else(|| vm.new_overflow_error("int too large to convert to float".to_owned())) -} -// num-bigint now returns Some(inf) for to_f64() in some cases, so just keep that the same for now -fn i2f(int: &BigInt) -> Option { - int.to_f64().filter(|f| f.is_finite()) + bigint_to_finite_float(int) + .ok_or_else(|| vm.new_overflow_error("int too large to convert to float".to_owned())) } pub(crate) fn init(context: &Context) { PyInt::extend_class(context, context.types.int_type); } - -#[test] -fn test_bytes_to_int() { - assert_eq!(bytes_to_int(&b"0b101"[..], 2).unwrap(), BigInt::from(5)); - assert_eq!(bytes_to_int(&b"0x_10"[..], 16).unwrap(), BigInt::from(16)); - assert_eq!(bytes_to_int(&b"0b"[..], 16).unwrap(), BigInt::from(11)); - assert_eq!(bytes_to_int(&b"+0b101"[..], 2).unwrap(), BigInt::from(5)); - assert_eq!(bytes_to_int(&b"0_0_0"[..], 10).unwrap(), BigInt::from(0)); - assert_eq!(bytes_to_int(&b"09_99"[..], 0), None); - assert_eq!(bytes_to_int(&b"000"[..], 0).unwrap(), BigInt::from(0)); - assert_eq!(bytes_to_int(&b"0_"[..], 0), None); - assert_eq!(bytes_to_int(&b"0_100"[..], 10).unwrap(), BigInt::from(100)); -} diff --git a/vm/src/protocol/number.rs b/vm/src/protocol/number.rs index f9a0ffb02..12bc10e7b 100644 --- a/vm/src/protocol/number.rs +++ b/vm/src/protocol/number.rs @@ -2,6 +2,7 @@ use crate::{ builtins::{ int, type_::PointerSlot, PyByteArray, PyBytes, PyComplex, PyFloat, PyInt, PyIntRef, PyStr, }, + common::int::bytes_to_int, function::ArgBytesLike, stdlib::warnings, AsObject, PyObject, PyObjectRef, PyPayload, PyRef, PyResult, TryFromBorrowedObject, @@ -40,14 +41,17 @@ impl PyObject { pub fn try_int(&self, vm: &VirtualMachine) -> PyResult { fn try_convert(obj: &PyObject, lit: &[u8], vm: &VirtualMachine) -> PyResult { let base = 10; - match int::bytes_to_int(lit, base) { - Some(i) => Ok(PyInt::from(i).into_ref(vm)), - None => Err(vm.new_value_error(format!( + let i = bytes_to_int(lit, base).ok_or_else(|| { + let repr = match obj.repr(vm) { + Ok(repr) => repr, + Err(err) => return err, + }; + vm.new_value_error(format!( "invalid literal for int() with base {}: {}", - base, - obj.repr(vm)?, - ))), - } + base, repr, + )) + })?; + Ok(PyInt::from(i).into_ref(vm)) } if let Some(i) = self.downcast_ref_if_exact::(vm) {