diff --git a/stdlib/src/binascii.rs b/stdlib/src/binascii.rs index 1843daace..2dca65deb 100644 --- a/stdlib/src/binascii.rs +++ b/stdlib/src/binascii.rs @@ -1,20 +1,23 @@ -pub(crate) use decl::make_module; - pub(super) use decl::crc32; +pub(crate) use decl::make_module; +use rustpython_vm::{builtins::PyBaseExceptionRef, convert::ToPyException, VirtualMachine}; + +const PAD: u8 = 61u8; +const MAXLINESIZE: usize = 76; // Excluding the CRLF #[pymodule(name = "binascii")] mod decl { + use super::{MAXLINESIZE, PAD}; use crate::vm::{ - builtins::{PyBaseExceptionRef, PyIntRef, PyTypeRef}, + builtins::{PyIntRef, PyTypeRef}, + convert::ToPyException, function::{ArgAsciiBuffer, ArgBytesLike, OptionalArg}, PyResult, VirtualMachine, }; use itertools::Itertools; - const MAXLINESIZE: usize = 76; - #[pyattr(name = "Error", once)] - fn error_type(vm: &VirtualMachine) -> PyTypeRef { + pub(super) fn error_type(vm: &VirtualMachine) -> PyTypeRef { vm.ctx.new_exception_type( "binascii", "Error", @@ -62,7 +65,10 @@ mod decl { fn unhexlify(data: ArgAsciiBuffer, vm: &VirtualMachine) -> PyResult> { data.with_ref(|hex_bytes| { if hex_bytes.len() % 2 != 0 { - return Err(new_binascii_error("Odd-length string".to_owned(), vm)); + return Err(super::new_binascii_error( + "Odd-length string".to_owned(), + vm, + )); } let mut unhex = Vec::::with_capacity(hex_bytes.len() / 2); @@ -70,7 +76,7 @@ mod decl { if let (Some(n1), Some(n2)) = (unhex_nibble(*n1), unhex_nibble(*n2)) { unhex.push(n1 << 4 | n2); } else { - return Err(new_binascii_error( + return Err(super::new_binascii_error( "Non-hexadecimal digit found".to_owned(), vm, )); @@ -139,10 +145,6 @@ mod decl { newline: bool, } - fn new_binascii_error(msg: String, vm: &VirtualMachine) -> PyBaseExceptionRef { - vm.new_exception_msg(error_type(vm), msg) - } - #[derive(FromArgs)] struct A2bBase64Args { #[pyarg(any)] @@ -177,8 +179,6 @@ mod decl { -1,-1,-1,-1, -1,-1,-1,-1, -1,-1,-1,-1, -1,-1,-1,-1, ]; - const PAD: u8 = 61u8; - let A2bBase64Args { s, strict_mode } = args; s.with_ref(|b| { if b.is_empty() { @@ -228,52 +228,43 @@ mod decl { pads = 0; // Decode individual ASCII character - if quad_pos == 0 { - quad_pos = 1; - left_char = binary_char as u8; - } else if quad_pos == 1 { - quad_pos = 2; - decoded.push((left_char << 2) | (binary_char >> 4) as u8); - left_char = (binary_char & 0x0f) as u8; - } else if quad_pos == 2 { - quad_pos = 3; - decoded.push((left_char << 4) | (binary_char >> 2) as u8); - left_char = (binary_char & 0x03) as u8; - } else if quad_pos == 3 { - quad_pos = 0; - decoded.push((left_char << 6) | binary_char as u8); - left_char = 0; + match quad_pos { + 0 => { + quad_pos = 1; + left_char = binary_char as u8; + } + 1 => { + quad_pos = 2; + decoded.push((left_char << 2) | (binary_char >> 4) as u8); + left_char = (binary_char & 0x0f) as u8; + } + 2 => { + quad_pos = 3; + decoded.push((left_char << 4) | (binary_char >> 2) as u8); + left_char = (binary_char & 0x03) as u8; + } + 3 => { + quad_pos = 0; + decoded.push((left_char << 6) | binary_char as u8); + left_char = 0; + } + _ => unsafe { + // quad_pos is only assigned in this match statement to constants + std::hint::unreachable_unchecked() + }, } } - return match quad_pos { + match quad_pos { 0 => Ok(decoded), - 1 => Err(base64::DecodeError::InvalidLastSymbol(decoded.len() / 3 * 4 + 1, 0)), - _ => Err(base64::DecodeError::InvalidLength) - }; - }) - .map_err(|err| { - let python_error = match err { - base64::DecodeError::InvalidByte(0, PAD) => { - String::from("Leading padding not allowed") - } - base64::DecodeError::InvalidByte(_, PAD) => { - String::from("Discontinuous padding not allowed") - } - base64::DecodeError::InvalidByte(_, _) => { - String::from("Only base64 data is allowed") - } - base64::DecodeError::InvalidLastSymbol(_, PAD) => { - String::from("Excess data after padding") - } - base64::DecodeError::InvalidLastSymbol(length, _) => { - format!("Invalid base64-encoded string: number of data characters {} cannot be 1 more than a multiple of 4", length) - } - base64::DecodeError::InvalidLength => String::from("Incorrect padding"), - }; - - new_binascii_error(format!("error decoding base64: {python_error}"), vm) + 1 => Err(base64::DecodeError::InvalidLastSymbol( + decoded.len() / 3 * 4 + 1, + 0, + )), + _ => Err(base64::DecodeError::InvalidLength), + } }) + .map_err(|err| super::Base64DecodeError(err).to_pyexception(vm)) } #[pyfunction] @@ -738,3 +729,26 @@ mod decl { }) } } + +struct Base64DecodeError(base64::DecodeError); + +fn new_binascii_error(msg: String, vm: &VirtualMachine) -> PyBaseExceptionRef { + vm.new_exception_msg(decl::error_type(vm), msg) +} + +impl ToPyException for Base64DecodeError { + fn to_pyexception(&self, vm: &VirtualMachine) -> PyBaseExceptionRef { + use base64::DecodeError::*; + let message = match self.0 { + InvalidByte(0, PAD) => "Leading padding not allowed".to_owned(), + InvalidByte(_, PAD) => "Discontinuous padding not allowed".to_owned(), + InvalidByte(_, _) => "Only base64 data is allowed".to_owned(), + InvalidLastSymbol(_, PAD) => "Excess data after padding".to_owned(), + InvalidLastSymbol(length, _) => { + format!("Invalid base64-encoded string: number of data characters {} cannot be 1 more than a multiple of 4", length) + } + InvalidLength => "Incorrect padding".to_owned(), + }; + new_binascii_error(format!("error decoding base64: {message}"), vm) + } +}