mirror of
https://github.com/RustPython/RustPython.git
synced 2026-06-02 19:39:49 +09:00
Implement strict_mode keyword for binascii.a2b_base64
This commit is contained in:
committed by
Jeong YunWon
parent
d7f65cbbcd
commit
404c398b59
2
Lib/test/test_binascii.py
vendored
2
Lib/test/test_binascii.py
vendored
@@ -114,8 +114,6 @@ class BinASCIITest(unittest.TestCase):
|
||||
# empty strings. TBD: shouldn't it raise an exception instead ?
|
||||
self.assertEqual(binascii.a2b_base64(self.type2test(fillers)), b'')
|
||||
|
||||
# TODO: RUSTPYTHON
|
||||
@unittest.expectedFailure
|
||||
def test_base64_strict_mode(self):
|
||||
# Test base64 with strict mode on
|
||||
def _assertRegexTemplate(assert_regex: str, data: bytes, non_strict_mode_expected_result: bytes):
|
||||
|
||||
@@ -2,13 +2,8 @@ pub(crate) use decl::make_module;
|
||||
|
||||
pub(super) use decl::crc32;
|
||||
|
||||
/// Decodes base64 `input` into raw bytes.
///
/// Uses the `base64` crate's `STANDARD` configuration with
/// `decode_allow_trailing_bits(true)` set, and forwards the crate's
/// `DecodeError` unchanged when decoding fails.
// NOTE(review): per the `base64` crate docs, allowing trailing bits presumably
// means a final symbol with non-zero unused bits is accepted rather than
// rejected — confirm against the pinned crate version.
pub fn decode<T: AsRef<[u8]>>(input: T) -> Result<Vec<u8>, base64::DecodeError> {
|
||||
base64::decode_config(input, base64::STANDARD.decode_allow_trailing_bits(true))
|
||||
}
|
||||
|
||||
#[pymodule(name = "binascii")]
|
||||
mod decl {
|
||||
use super::decode;
|
||||
use crate::vm::{
|
||||
builtins::{PyBaseExceptionRef, PyIntRef, PyTypeRef},
|
||||
function::{ArgAsciiBuffer, ArgBytesLike, OptionalArg},
|
||||
@@ -148,9 +143,20 @@ mod decl {
|
||||
vm.new_exception_msg(error_type(vm), msg)
|
||||
}
|
||||
|
||||
// Argument bundle for `a2b_base64`; destructured into `s` and `strict_mode`
// at the start of that function.
#[derive(FromArgs)]
|
||||
struct A2bBase64Args {
|
||||
// The base64-encoded input to decode.
// NOTE(review): `pyarg(any)` presumably allows positional or keyword passing — confirm.
#[pyarg(any)]
|
||||
s: ArgAsciiBuffer,
|
||||
// Declared `named` with a default of `false`. When `true`, `a2b_base64`
// returns an error for leading padding, non-base64 bytes, discontinuous
// padding, and data following the padding instead of tolerating them.
#[pyarg(named, default = "false")]
|
||||
strict_mode: bool,
|
||||
}
|
||||
|
||||
#[pyfunction]
|
||||
fn a2b_base64(s: ArgAsciiBuffer, vm: &VirtualMachine) -> PyResult<Vec<u8>> {
|
||||
fn a2b_base64(args: A2bBase64Args, vm: &VirtualMachine) -> PyResult<Vec<u8>> {
|
||||
#[rustfmt::skip]
|
||||
// Lookup table from ASCII to base-64: the index is a character's ASCII code and the element at
|
||||
// that index is the character's base-64 value (-1 marks a byte outside the alphabet). For example
|
||||
// "=" is 61 in ASCII but 0 (since it's the pad character) in base-64, so BASE64_TABLE[61] == 0
|
||||
const BASE64_TABLE: [i8; 256] = [
|
||||
-1,-1,-1,-1, -1,-1,-1,-1, -1,-1,-1,-1, -1,-1,-1,-1,
|
||||
-1,-1,-1,-1, -1,-1,-1,-1, -1,-1,-1,-1, -1,-1,-1,-1,
|
||||
@@ -171,25 +177,107 @@ mod decl {
|
||||
-1,-1,-1,-1, -1,-1,-1,-1, -1,-1,-1,-1, -1,-1,-1,-1,
|
||||
];
|
||||
|
||||
const PAD: u8 = 61u8;
|
||||
|
||||
let A2bBase64Args { s, strict_mode } = args;
|
||||
s.with_ref(|b| {
|
||||
let decoded = if b.len() % 4 == 0 {
|
||||
decode(b)
|
||||
} else {
|
||||
Err(base64::DecodeError::InvalidLength)
|
||||
};
|
||||
decoded.or_else(|_| {
|
||||
let buf: Vec<_> = b
|
||||
.iter()
|
||||
.copied()
|
||||
.filter(|&c| BASE64_TABLE[c as usize] != -1)
|
||||
.collect();
|
||||
if buf.len() % 4 != 0 {
|
||||
return Err(base64::DecodeError::InvalidLength);
|
||||
if b.is_empty() {
|
||||
return Ok(vec![]);
|
||||
}
|
||||
|
||||
if strict_mode && b[0] == PAD {
|
||||
return Err(base64::DecodeError::InvalidByte(0, 61));
|
||||
}
|
||||
|
||||
let mut decoded: Vec<u8> = vec![];
|
||||
|
||||
let mut quad_pos = 0; // position within the current 4-character base64 group (0..=3)
|
||||
let mut pads = 0;
|
||||
let mut left_char: u8 = 0;
|
||||
let mut padding_started = false;
|
||||
for (i, &el) in b.iter().enumerate() {
|
||||
if el == PAD {
|
||||
padding_started = true;
|
||||
|
||||
pads += 1;
|
||||
if quad_pos >= 2 && quad_pos + pads >= 4 {
|
||||
if strict_mode && i + 1 < b.len() {
|
||||
// Represents excess data after padding error
|
||||
return Err(base64::DecodeError::InvalidLastSymbol(i, el));
|
||||
}
|
||||
|
||||
return Ok(decoded);
|
||||
}
|
||||
|
||||
continue;
|
||||
}
|
||||
decode(&buf)
|
||||
})
|
||||
|
||||
let binary_char = BASE64_TABLE[el as usize];
|
||||
if binary_char >= 64 || binary_char == -1 {
|
||||
if strict_mode {
|
||||
// Represents non-base64 data error
|
||||
return Err(base64::DecodeError::InvalidByte(i, el));
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
if strict_mode && padding_started {
|
||||
// Represents discontinuous padding error
|
||||
return Err(base64::DecodeError::InvalidByte(i, PAD));
|
||||
}
|
||||
pads = 0;
|
||||
|
||||
// Decode individual ASCII character
|
||||
if quad_pos == 0 {
|
||||
quad_pos = 1;
|
||||
left_char = binary_char as u8;
|
||||
} else if quad_pos == 1 {
|
||||
quad_pos = 2;
|
||||
decoded.push((left_char << 2) | (binary_char >> 4) as u8);
|
||||
left_char = (binary_char & 0x0f) as u8;
|
||||
} else if quad_pos == 2 {
|
||||
quad_pos = 3;
|
||||
decoded.push((left_char << 4) | (binary_char >> 2) as u8);
|
||||
left_char = (binary_char & 0x03) as u8;
|
||||
} else if quad_pos == 3 {
|
||||
quad_pos = 0;
|
||||
decoded.push((left_char << 6) | binary_char as u8);
|
||||
left_char = 0;
|
||||
}
|
||||
}
|
||||
|
||||
if quad_pos == 1 {
|
||||
// Ensure that a PAD never gets passed, since that'd mistakenly cause an excess
|
||||
// data after padding error
|
||||
return Err(base64::DecodeError::InvalidLastSymbol(
|
||||
decoded.len() / 3 * 4 + 1,
|
||||
0,
|
||||
));
|
||||
} else if quad_pos > 1 {
|
||||
return Err(base64::DecodeError::InvalidLength);
|
||||
}
|
||||
|
||||
Ok(decoded)
|
||||
})
|
||||
.map_err(|err| {
|
||||
let python_error = match err {
|
||||
base64::DecodeError::InvalidByte(0, PAD) => {
|
||||
String::from("Leading padding not allowed")
|
||||
}
|
||||
base64::DecodeError::InvalidByte(_, PAD) => {
|
||||
String::from("Discontinuous padding not allowed")
|
||||
}
|
||||
base64::DecodeError::InvalidByte(_, _) => {
|
||||
String::from("Only base64 data is allowed")
|
||||
}
|
||||
base64::DecodeError::InvalidLastSymbol(_, _) => {
|
||||
String::from("Excess data after padding")
|
||||
}
|
||||
base64::DecodeError::InvalidLength => String::from("Not implemented (yet)"),
|
||||
};
|
||||
|
||||
new_binascii_error(format!("error decoding base64: {python_error}"), vm)
|
||||
})
|
||||
.map_err(|err| new_binascii_error(format!("error decoding base64: {err}"), vm))
|
||||
}
|
||||
|
||||
#[pyfunction]
|
||||
|
||||
Reference in New Issue
Block a user