ToPyException for base64::DecodeError

This commit is contained in:
Jeong YunWon
2023-03-04 14:02:31 +09:00
parent 362be9f344
commit ff973caa67

View File

@@ -1,20 +1,23 @@
pub(crate) use decl::make_module;
pub(super) use decl::crc32;
pub(crate) use decl::make_module;
use rustpython_vm::{builtins::PyBaseExceptionRef, convert::ToPyException, VirtualMachine};
const PAD: u8 = 61u8;
const MAXLINESIZE: usize = 76; // Excluding the CRLF
#[pymodule(name = "binascii")]
mod decl {
use super::{MAXLINESIZE, PAD};
use crate::vm::{
builtins::{PyBaseExceptionRef, PyIntRef, PyTypeRef},
builtins::{PyIntRef, PyTypeRef},
convert::ToPyException,
function::{ArgAsciiBuffer, ArgBytesLike, OptionalArg},
PyResult, VirtualMachine,
};
use itertools::Itertools;
const MAXLINESIZE: usize = 76;
#[pyattr(name = "Error", once)]
fn error_type(vm: &VirtualMachine) -> PyTypeRef {
pub(super) fn error_type(vm: &VirtualMachine) -> PyTypeRef {
vm.ctx.new_exception_type(
"binascii",
"Error",
@@ -62,7 +65,10 @@ mod decl {
fn unhexlify(data: ArgAsciiBuffer, vm: &VirtualMachine) -> PyResult<Vec<u8>> {
data.with_ref(|hex_bytes| {
if hex_bytes.len() % 2 != 0 {
return Err(new_binascii_error("Odd-length string".to_owned(), vm));
return Err(super::new_binascii_error(
"Odd-length string".to_owned(),
vm,
));
}
let mut unhex = Vec::<u8>::with_capacity(hex_bytes.len() / 2);
@@ -70,7 +76,7 @@ mod decl {
if let (Some(n1), Some(n2)) = (unhex_nibble(*n1), unhex_nibble(*n2)) {
unhex.push(n1 << 4 | n2);
} else {
return Err(new_binascii_error(
return Err(super::new_binascii_error(
"Non-hexadecimal digit found".to_owned(),
vm,
));
@@ -139,10 +145,6 @@ mod decl {
newline: bool,
}
fn new_binascii_error(msg: String, vm: &VirtualMachine) -> PyBaseExceptionRef {
vm.new_exception_msg(error_type(vm), msg)
}
#[derive(FromArgs)]
struct A2bBase64Args {
#[pyarg(any)]
@@ -177,8 +179,6 @@ mod decl {
-1,-1,-1,-1, -1,-1,-1,-1, -1,-1,-1,-1, -1,-1,-1,-1,
];
const PAD: u8 = 61u8;
let A2bBase64Args { s, strict_mode } = args;
s.with_ref(|b| {
if b.is_empty() {
@@ -228,52 +228,43 @@ mod decl {
pads = 0;
// Decode individual ASCII character
if quad_pos == 0 {
quad_pos = 1;
left_char = binary_char as u8;
} else if quad_pos == 1 {
quad_pos = 2;
decoded.push((left_char << 2) | (binary_char >> 4) as u8);
left_char = (binary_char & 0x0f) as u8;
} else if quad_pos == 2 {
quad_pos = 3;
decoded.push((left_char << 4) | (binary_char >> 2) as u8);
left_char = (binary_char & 0x03) as u8;
} else if quad_pos == 3 {
quad_pos = 0;
decoded.push((left_char << 6) | binary_char as u8);
left_char = 0;
match quad_pos {
0 => {
quad_pos = 1;
left_char = binary_char as u8;
}
1 => {
quad_pos = 2;
decoded.push((left_char << 2) | (binary_char >> 4) as u8);
left_char = (binary_char & 0x0f) as u8;
}
2 => {
quad_pos = 3;
decoded.push((left_char << 4) | (binary_char >> 2) as u8);
left_char = (binary_char & 0x03) as u8;
}
3 => {
quad_pos = 0;
decoded.push((left_char << 6) | binary_char as u8);
left_char = 0;
}
_ => unsafe {
// quad_pos is only assigned in this match statement to constants
std::hint::unreachable_unchecked()
},
}
}
return match quad_pos {
match quad_pos {
0 => Ok(decoded),
1 => Err(base64::DecodeError::InvalidLastSymbol(decoded.len() / 3 * 4 + 1, 0)),
_ => Err(base64::DecodeError::InvalidLength)
};
})
.map_err(|err| {
let python_error = match err {
base64::DecodeError::InvalidByte(0, PAD) => {
String::from("Leading padding not allowed")
}
base64::DecodeError::InvalidByte(_, PAD) => {
String::from("Discontinuous padding not allowed")
}
base64::DecodeError::InvalidByte(_, _) => {
String::from("Only base64 data is allowed")
}
base64::DecodeError::InvalidLastSymbol(_, PAD) => {
String::from("Excess data after padding")
}
base64::DecodeError::InvalidLastSymbol(length, _) => {
format!("Invalid base64-encoded string: number of data characters {} cannot be 1 more than a multiple of 4", length)
}
base64::DecodeError::InvalidLength => String::from("Incorrect padding"),
};
new_binascii_error(format!("error decoding base64: {python_error}"), vm)
1 => Err(base64::DecodeError::InvalidLastSymbol(
decoded.len() / 3 * 4 + 1,
0,
)),
_ => Err(base64::DecodeError::InvalidLength),
}
})
.map_err(|err| super::Base64DecodeError(err).to_pyexception(vm))
}
#[pyfunction]
@@ -738,3 +729,26 @@ mod decl {
})
}
}
struct Base64DecodeError(base64::DecodeError);
fn new_binascii_error(msg: String, vm: &VirtualMachine) -> PyBaseExceptionRef {
vm.new_exception_msg(decl::error_type(vm), msg)
}
impl ToPyException for Base64DecodeError {
fn to_pyexception(&self, vm: &VirtualMachine) -> PyBaseExceptionRef {
use base64::DecodeError::*;
let message = match self.0 {
InvalidByte(0, PAD) => "Leading padding not allowed".to_owned(),
InvalidByte(_, PAD) => "Discontinuous padding not allowed".to_owned(),
InvalidByte(_, _) => "Only base64 data is allowed".to_owned(),
InvalidLastSymbol(_, PAD) => "Excess data after padding".to_owned(),
InvalidLastSymbol(length, _) => {
format!("Invalid base64-encoded string: number of data characters {} cannot be 1 more than a multiple of 4", length)
}
InvalidLength => "Incorrect padding".to_owned(),
};
new_binascii_error(format!("error decoding base64: {message}"), vm)
}
}