diff --git a/vm/src/anystr.rs b/vm/src/anystr.rs index 4cd7a8f76..0611950ef 100644 --- a/vm/src/anystr.rs +++ b/vm/src/anystr.rs @@ -1,9 +1,9 @@ use crate::{ - builtins::PyIntRef, + builtins::{PyIntRef, PyTupleRef}, cformat::CFormatString, - function::{single_or_tuple_any, OptionalOption}, + function::OptionalOption, protocol::PyIterIter, - AsObject, PyObjectRef, PyResult, TryFromObject, VirtualMachine, + AsObject, PyObject, PyObjectRef, PyResult, TryFromObject, VirtualMachine, }; use num_traits::{cast::ToPrimitive, sign::Signed}; use std::str::FromStr; @@ -441,3 +441,33 @@ pub trait AnyStr<'s>: 's { .format(vm, values) } } + +/// Tests that the predicate is True on a single value, or if the value is a tuple a tuple, then +/// test that any of the values contained within the tuples satisfies the predicate. Type parameter +/// T specifies the type that is expected, if the input value is not of that type or a tuple of +/// values of that type, then a TypeError is raised. +pub fn single_or_tuple_any( + obj: PyObjectRef, + predicate: &F, + message: &M, + vm: &VirtualMachine, +) -> PyResult +where + T: TryFromObject, + F: Fn(&T) -> PyResult, + M: Fn(&PyObject) -> String, +{ + match T::try_from_object(vm, obj.clone()) { + Ok(single) => (predicate)(&single), + Err(_) => { + let tuple = PyTupleRef::try_from_object(vm, obj.clone()) + .map_err(|_| vm.new_type_error((message)(&obj)))?; + for obj in &tuple { + if single_or_tuple_any(obj.clone(), predicate, message, vm)? { + return Ok(true); + } + } + Ok(false) + } + } +} diff --git a/vm/src/builtins/int.rs b/vm/src/builtins/int.rs index 920fc703d..26944249d 100644 --- a/vm/src/builtins/int.rs +++ b/vm/src/builtins/int.rs @@ -5,7 +5,10 @@ use crate::{ common::hash, convert::{ToPyObject, ToPyResult}, format::FormatSpec, - function::{ArgIntoBool, OptionalArg, OptionalOption, PyArithmeticValue, PyComparisonValue}, + function::{ + ArgByteOrder, ArgIntoBool, OptionalArg, OptionalOption, PyArithmeticValue, + PyComparisonValue, + }, types::{Comparable, Constructor, Hashable, PyComparisonOp}, AsObject, Context, PyObject, PyObjectRef, PyPayload, PyRef, PyResult, TryFromBorrowedObject, VirtualMachine, @@ -425,62 +428,63 @@ impl PyInt { self.int_op(other, |a, b| a & b, vm) } + fn modpow(&self, other: PyObjectRef, modulus: PyObjectRef, vm: &VirtualMachine) -> PyResult { + let modulus = match modulus.payload_if_subclass::(vm) { + Some(val) => val.as_bigint(), + None => return Ok(vm.ctx.not_implemented()), + }; + if modulus.is_zero() { + return Err(vm.new_value_error("pow() 3rd argument cannot be 0".to_owned())); + } + + self.general_op( + other, + |a, b| { + let i = if b.is_negative() { + // modular multiplicative inverse + // based on rust-num/num-integer#10, should hopefully be published soon + fn normalize(a: BigInt, n: &BigInt) -> BigInt { + let a = a % n; + if a.is_negative() { + a + n + } else { + a + } + } + fn inverse(a: BigInt, n: &BigInt) -> Option { + use num_integer::*; + let ExtendedGcd { gcd, x: c, .. } = a.extended_gcd(n); + if gcd.is_one() { + Some(normalize(c, n)) + } else { + None + } + } + let a = inverse(a % modulus, modulus).ok_or_else(|| { + vm.new_value_error( + "base is not invertible for the given modulus".to_owned(), + ) + })?; + let b = -b; + a.modpow(&b, modulus) + } else { + a.modpow(b, modulus) + }; + Ok(vm.ctx.new_int(i).into()) + }, + vm, + ) + } + #[pymethod(magic)] fn pow( &self, other: PyObjectRef, - mod_val: OptionalOption, + r#mod: OptionalOption, vm: &VirtualMachine, ) -> PyResult { - match mod_val.flatten() { - Some(int_ref) => { - let int = match int_ref.payload_if_subclass::(vm) { - Some(val) => val, - None => return Ok(vm.ctx.not_implemented()), - }; - - let modulus = int.as_bigint(); - if modulus.is_zero() { - return Err(vm.new_value_error("pow() 3rd argument cannot be 0".to_owned())); - } - self.general_op( - other, - |a, b| { - let i = if b.is_negative() { - // modular multiplicative inverse - // based on rust-num/num-integer#10, should hopefully be published soon - fn normalize(a: BigInt, n: &BigInt) -> BigInt { - let a = a % n; - if a.is_negative() { - a + n - } else { - a - } - } - fn inverse(a: BigInt, n: &BigInt) -> Option { - use num_integer::*; - let ExtendedGcd { gcd, x: c, .. } = a.extended_gcd(n); - if gcd.is_one() { - Some(normalize(c, n)) - } else { - None - } - } - let a = inverse(a % modulus, modulus).ok_or_else(|| { - vm.new_value_error( - "base is not invertible for the given modulus".to_owned(), - ) - })?; - let b = -b; - a.modpow(&b, modulus) - } else { - a.modpow(b, modulus) - }; - Ok(vm.ctx.new_int(i).into()) - }, - vm, - ) - } + match r#mod.flatten() { + Some(modulus) => self.modpow(other, modulus, vm), None => self.general_op(other, |a, b| inner_pow(a, b, vm), vm), } } @@ -529,20 +533,13 @@ impl PyInt { match precision { OptionalArg::Missing => (), OptionalArg::Present(ref value) => { - if !vm.is_none(value) { - // Only accept int type ndigits - let _ndigits = value.payload_if_subclass::(vm).ok_or_else(|| { - vm.new_type_error(format!( - "'{}' object cannot be interpreted as an integer", - value.class().name() - )) - })?; - } else { - return Err(vm.new_type_error(format!( + // Only accept int type ndigits + let _ndigits = value.payload_if_subclass::(vm).ok_or_else(|| { + vm.new_type_error(format!( "'{}' object cannot be interpreted as an integer", value.class().name() - ))); - } + )) + })?; } } Ok(zelf) @@ -595,12 +592,9 @@ impl PyInt { #[pymethod(magic)] fn format(&self, spec: PyStrRef, vm: &VirtualMachine) -> PyResult { - match FormatSpec::parse(spec.as_str()) + FormatSpec::parse(spec.as_str()) .and_then(|format_spec| format_spec.format_int(&self.value)) - { - Ok(string) => Ok(string), - Err(err) => Err(vm.new_value_error(err.to_string())), - } + .map_err(|msg| vm.new_value_error(msg.to_owned())) } #[pymethod(magic)] @@ -635,15 +629,12 @@ impl PyInt { vm: &VirtualMachine, ) -> PyResult> { let signed = args.signed.map_or(false, Into::into); - let value = match (args.byteorder.as_str(), signed) { - ("big", true) => BigInt::from_signed_bytes_be(&args.bytes.elements), - ("big", false) => BigInt::from_bytes_be(Sign::Plus, &args.bytes.elements), - ("little", true) => BigInt::from_signed_bytes_le(&args.bytes.elements), - ("little", false) => BigInt::from_bytes_le(Sign::Plus, &args.bytes.elements), - _ => { - return Err( - vm.new_value_error("byteorder must be either 'little' or 'big'".to_owned()) - ) + let value = match (args.byteorder, signed) { + (ArgByteOrder::Big, true) => BigInt::from_signed_bytes_be(&args.bytes.elements), + (ArgByteOrder::Big, false) => BigInt::from_bytes_be(Sign::Plus, &args.bytes.elements), + (ArgByteOrder::Little, true) => BigInt::from_signed_bytes_le(&args.bytes.elements), + (ArgByteOrder::Little, false) => { + BigInt::from_bytes_le(Sign::Plus, &args.bytes.elements) } }; Self::with_value(cls, value, vm) @@ -665,16 +656,11 @@ impl PyInt { _ => {} } - let mut origin_bytes = match (args.byteorder.as_str(), signed) { - ("big", true) => value.to_signed_bytes_be(), - ("big", false) => value.to_bytes_be().1, - ("little", true) => value.to_signed_bytes_le(), - ("little", false) => value.to_bytes_le().1, - _ => { - return Err( - vm.new_value_error("byteorder must be either 'little' or 'big'".to_owned()) - ); - } + let mut origin_bytes = match (args.byteorder, signed) { + (ArgByteOrder::Big, true) => value.to_signed_bytes_be(), + (ArgByteOrder::Big, false) => value.to_bytes_be().1, + (ArgByteOrder::Little, true) => value.to_signed_bytes_le(), + (ArgByteOrder::Little, false) => value.to_bytes_le().1, }; let origin_len = origin_bytes.len(); @@ -687,21 +673,21 @@ impl PyInt { _ => vec![0u8; byte_len - origin_len], }; - let bytes = match args.byteorder.as_str() { - "big" => { + let bytes = match args.byteorder { + ArgByteOrder::Big => { let mut bytes = append_bytes; bytes.append(&mut origin_bytes); bytes } - "little" => { + ArgByteOrder::Little => { let mut bytes = origin_bytes; bytes.append(&mut append_bytes); bytes } - _ => Vec::new(), }; Ok(bytes.into()) } + #[pyproperty] fn real(&self, vm: &VirtualMachine) -> PyRef { // subclasses must return int here @@ -768,7 +754,7 @@ pub struct IntOptions { #[derive(FromArgs)] struct IntFromByteArgs { bytes: PyBytesInner, - byteorder: PyStrRef, + byteorder: ArgByteOrder, #[pyarg(named, optional)] signed: OptionalArg, } @@ -776,7 +762,7 @@ struct IntFromByteArgs { #[derive(FromArgs)] struct IntToByteArgs { length: PyIntRef, - byteorder: PyStrRef, + byteorder: ArgByteOrder, #[pyarg(named, optional)] signed: OptionalArg, } diff --git a/vm/src/function/mod.rs b/vm/src/function/mod.rs index 76f203c10..0091e16d9 100644 --- a/vm/src/function/mod.rs +++ b/vm/src/function/mod.rs @@ -17,36 +17,25 @@ pub use either::Either; pub use number::{ArgIntoBool, ArgIntoComplex, ArgIntoFloat}; pub use protocol::{ArgCallable, ArgIterable, ArgMapping, ArgSequence}; -use crate::{ - builtins::PyTupleRef, convert::TryFromObject, PyObject, PyObjectRef, PyResult, VirtualMachine, -}; +use crate::{builtins::PyStr, convert::TryFromBorrowedObject, PyObject, PyResult, VirtualMachine}; -/// Tests that the predicate is True on a single value, or if the value is a tuple a tuple, then -/// test that any of the values contained within the tuples satisfies the predicate. Type parameter -/// T specifies the type that is expected, if the input value is not of that type or a tuple of -/// values of that type, then a TypeError is raised. -pub fn single_or_tuple_any( - obj: PyObjectRef, - predicate: &F, - message: &M, - vm: &VirtualMachine, -) -> PyResult -where - T: TryFromObject, - F: Fn(&T) -> PyResult, - M: Fn(&PyObject) -> String, -{ - match T::try_from_object(vm, obj.clone()) { - Ok(single) => (predicate)(&single), - Err(_) => { - let tuple = PyTupleRef::try_from_object(vm, obj.clone()) - .map_err(|_| vm.new_type_error((message)(&obj)))?; - for obj in &tuple { - if single_or_tuple_any(obj.clone(), predicate, message, vm)? { - return Ok(true); +#[derive(Clone, Copy, PartialEq, Eq)] +pub enum ArgByteOrder { + Big, + Little, +} + +impl TryFromBorrowedObject for ArgByteOrder { + fn try_from_borrowed_object(vm: &VirtualMachine, obj: &PyObject) -> PyResult { + obj.try_value_with( + |s: &PyStr| match s.as_str() { + "big" => Ok(Self::Big), + "little" => Ok(Self::Little), + _ => { + Err(vm.new_value_error("byteorder must be either 'little' or 'big'".to_owned())) } - } - Ok(false) - } + }, + vm, + ) } }