From 404c398b5935b7426c5bbb0e01fa3350ac99ffcb Mon Sep 17 00:00:00 2001 From: Evan Rittenhouse Date: Wed, 18 Jan 2023 22:52:10 -0600 Subject: [PATCH] Implement `strict_mode` keyword for binascii.a2b_base64 --- Lib/test/test_binascii.py | 2 - stdlib/src/binascii.rs | 132 +++++++++++++++++++++++++++++++------- 2 files changed, 110 insertions(+), 24 deletions(-) diff --git a/Lib/test/test_binascii.py b/Lib/test/test_binascii.py index ef5c2e79e..ac02aed96 100644 --- a/Lib/test/test_binascii.py +++ b/Lib/test/test_binascii.py @@ -114,8 +114,6 @@ class BinASCIITest(unittest.TestCase): # empty strings. TBD: shouldn't it raise an exception instead ? self.assertEqual(binascii.a2b_base64(self.type2test(fillers)), b'') - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_base64_strict_mode(self): # Test base64 with strict mode on def _assertRegexTemplate(assert_regex: str, data: bytes, non_strict_mode_expected_result: bytes): diff --git a/stdlib/src/binascii.rs b/stdlib/src/binascii.rs index c9fd34877..c0c1b624a 100644 --- a/stdlib/src/binascii.rs +++ b/stdlib/src/binascii.rs @@ -2,13 +2,8 @@ pub(crate) use decl::make_module; pub(super) use decl::crc32; -pub fn decode>(input: T) -> Result, base64::DecodeError> { - base64::decode_config(input, base64::STANDARD.decode_allow_trailing_bits(true)) -} - #[pymodule(name = "binascii")] mod decl { - use super::decode; use crate::vm::{ builtins::{PyBaseExceptionRef, PyIntRef, PyTypeRef}, function::{ArgAsciiBuffer, ArgBytesLike, OptionalArg}, @@ -148,9 +143,20 @@ mod decl { vm.new_exception_msg(error_type(vm), msg) } + #[derive(FromArgs)] + struct A2bBase64Args { + #[pyarg(any)] + s: ArgAsciiBuffer, + #[pyarg(named, default = "false")] + strict_mode: bool, + } + #[pyfunction] - fn a2b_base64(s: ArgAsciiBuffer, vm: &VirtualMachine) -> PyResult> { + fn a2b_base64(args: A2bBase64Args, vm: &VirtualMachine) -> PyResult> { #[rustfmt::skip] + // Converts between ASCII and base-64 characters. 
The index into the table is a + // character's ASCII code; the value stored there is its base-64 value (-1 if it has none). For example, + // "=" is 61 in ASCII but 0 (since it's the pad character) in base-64, so BASE64_TABLE[61] == 0 const BASE64_TABLE: [i8; 256] = [ -1,-1,-1,-1, -1,-1,-1,-1, -1,-1,-1,-1, -1,-1,-1,-1, -1,-1,-1,-1, -1,-1,-1,-1, -1,-1,-1,-1, -1,-1,-1,-1, @@ -171,25 +177,107 @@ mod decl { -1,-1,-1,-1, -1,-1,-1,-1, -1,-1,-1,-1, -1,-1,-1,-1, ]; + const PAD: u8 = 61u8; + + let A2bBase64Args { s, strict_mode } = args; s.with_ref(|b| { - let decoded = if b.len() % 4 == 0 { - decode(b) - } else { - Err(base64::DecodeError::InvalidLength) - }; - decoded.or_else(|_| { - let buf: Vec<_> = b - .iter() - .copied() - .filter(|&c| BASE64_TABLE[c as usize] != -1) - .collect(); - if buf.len() % 4 != 0 { - return Err(base64::DecodeError::InvalidLength); + if b.is_empty() { + return Ok(vec![]); + } + + if strict_mode && b[0] == PAD { + return Err(base64::DecodeError::InvalidByte(0, 61)); + } + + let mut decoded: Vec = vec![]; + + let mut quad_pos = 0; // position within the current 4-character quad + let mut pads = 0; + let mut left_char: u8 = 0; + let mut padding_started = false; + for (i, &el) in b.iter().enumerate() { + if el == PAD { + padding_started = true; + + pads += 1; + if quad_pos >= 2 && quad_pos + pads >= 4 { + if strict_mode && i + 1 < b.len() { + // Represents excess data after padding error + return Err(base64::DecodeError::InvalidLastSymbol(i, el)); + } + + return Ok(decoded); + } + + continue; } - decode(&buf) - }) + + let binary_char = BASE64_TABLE[el as usize]; + if binary_char >= 64 || binary_char == -1 { + if strict_mode { + // Represents non-base64 data error + return Err(base64::DecodeError::InvalidByte(i, el)); + } + continue; + } + + if strict_mode && padding_started { + // Represents discontinuous padding error + return Err(base64::DecodeError::InvalidByte(i, PAD)); + } + pads = 0; + + // Decode individual ASCII character + if quad_pos == 0 { + quad_pos = 1; + 
left_char = binary_char as u8; + } else if quad_pos == 1 { + quad_pos = 2; + decoded.push((left_char << 2) | (binary_char >> 4) as u8); + left_char = (binary_char & 0x0f) as u8; + } else if quad_pos == 2 { + quad_pos = 3; + decoded.push((left_char << 4) | (binary_char >> 2) as u8); + left_char = (binary_char & 0x03) as u8; + } else if quad_pos == 3 { + quad_pos = 0; + decoded.push((left_char << 6) | binary_char as u8); + left_char = 0; + } + } + + if quad_pos == 1 { + // Ensure that a PAD never gets passed, since that'd mistakenly cause an excess + // data after padding error + return Err(base64::DecodeError::InvalidLastSymbol( + decoded.len() / 3 * 4 + 1, + 0, + )); + } else if quad_pos > 1 { + return Err(base64::DecodeError::InvalidLength); + } + + Ok(decoded) + }) + .map_err(|err| { + let python_error = match err { + base64::DecodeError::InvalidByte(0, PAD) => { + String::from("Leading padding not allowed") + } + base64::DecodeError::InvalidByte(_, PAD) => { + String::from("Discontinuous padding not allowed") + } + base64::DecodeError::InvalidByte(_, _) => { + String::from("Only base64 data is allowed") + } + base64::DecodeError::InvalidLastSymbol(_, _) => { + String::from("Excess data after padding") + } + base64::DecodeError::InvalidLength => String::from("Not implemented (yet)"), + }; + + new_binascii_error(format!("error decoding base64: {python_error}"), vm) }) - .map_err(|err| new_binascii_error(format!("error decoding base64: {err}"), vm)) } #[pyfunction]