mirror of
https://github.com/RustPython/RustPython.git
synced 2026-06-02 19:39:49 +09:00
Merge pull request #4777 from MegasKomnenos/base64
make int.to_bytes arguments optional and update base64.py and test_base64.py from CPython v3.11.2
This commit is contained in:
136
Lib/base64.py
vendored
136
Lib/base64.py
vendored
@@ -1,4 +1,4 @@
|
||||
#! /usr/bin/python3.6
|
||||
#! /usr/bin/env python3
|
||||
|
||||
"""Base16, Base32, Base64 (RFC 3548), Base85 and Ascii85 data encodings"""
|
||||
|
||||
@@ -16,7 +16,7 @@ __all__ = [
|
||||
'encode', 'decode', 'encodebytes', 'decodebytes',
|
||||
# Generalized interface for other encodings
|
||||
'b64encode', 'b64decode', 'b32encode', 'b32decode',
|
||||
'b16encode', 'b16decode',
|
||||
'b32hexencode', 'b32hexdecode', 'b16encode', 'b16decode',
|
||||
# Base85 and Ascii85 encodings
|
||||
'b85encode', 'b85decode', 'a85encode', 'a85decode',
|
||||
# Standard Base64 encoding
|
||||
@@ -76,15 +76,16 @@ def b64decode(s, altchars=None, validate=False):
|
||||
normal base-64 alphabet nor the alternative alphabet are discarded prior
|
||||
to the padding check. If validate is True, these non-alphabet characters
|
||||
in the input result in a binascii.Error.
|
||||
For more information about the strict base64 check, see:
|
||||
|
||||
https://docs.python.org/3.11/library/binascii.html#binascii.a2b_base64
|
||||
"""
|
||||
s = _bytes_from_decode_data(s)
|
||||
if altchars is not None:
|
||||
altchars = _bytes_from_decode_data(altchars)
|
||||
assert len(altchars) == 2, repr(altchars)
|
||||
s = s.translate(bytes.maketrans(altchars, b'+/'))
|
||||
if validate and not re.match(b'^[A-Za-z0-9+/]*={0,2}$', s):
|
||||
raise binascii.Error('Non-base64 digit found')
|
||||
return binascii.a2b_base64(s)
|
||||
return binascii.a2b_base64(s, strict_mode=validate)
|
||||
|
||||
|
||||
def standard_b64encode(s):
|
||||
@@ -135,19 +136,40 @@ def urlsafe_b64decode(s):
|
||||
|
||||
|
||||
# Base32 encoding/decoding must be done in Python
|
||||
_b32alphabet = b'ABCDEFGHIJKLMNOPQRSTUVWXYZ234567'
|
||||
_b32tab2 = None
|
||||
_b32rev = None
|
||||
_B32_ENCODE_DOCSTRING = '''
|
||||
Encode the bytes-like objects using {encoding} and return a bytes object.
|
||||
'''
|
||||
_B32_DECODE_DOCSTRING = '''
|
||||
Decode the {encoding} encoded bytes-like object or ASCII string s.
|
||||
|
||||
def b32encode(s):
|
||||
"""Encode the bytes-like object s using Base32 and return a bytes object.
|
||||
"""
|
||||
Optional casefold is a flag specifying whether a lowercase alphabet is
|
||||
acceptable as input. For security purposes, the default is False.
|
||||
{extra_args}
|
||||
The result is returned as a bytes object. A binascii.Error is raised if
|
||||
the input is incorrectly padded or if there are non-alphabet
|
||||
characters present in the input.
|
||||
'''
|
||||
_B32_DECODE_MAP01_DOCSTRING = '''
|
||||
RFC 3548 allows for optional mapping of the digit 0 (zero) to the
|
||||
letter O (oh), and for optional mapping of the digit 1 (one) to
|
||||
either the letter I (eye) or letter L (el). The optional argument
|
||||
map01 when not None, specifies which letter the digit 1 should be
|
||||
mapped to (when map01 is not None, the digit 0 is always mapped to
|
||||
the letter O). For security purposes the default is None, so that
|
||||
0 and 1 are not allowed in the input.
|
||||
'''
|
||||
_b32alphabet = b'ABCDEFGHIJKLMNOPQRSTUVWXYZ234567'
|
||||
_b32hexalphabet = b'0123456789ABCDEFGHIJKLMNOPQRSTUV'
|
||||
_b32tab2 = {}
|
||||
_b32rev = {}
|
||||
|
||||
def _b32encode(alphabet, s):
|
||||
global _b32tab2
|
||||
# Delay the initialization of the table to not waste memory
|
||||
# if the function is never called
|
||||
if _b32tab2 is None:
|
||||
b32tab = [bytes((i,)) for i in _b32alphabet]
|
||||
_b32tab2 = [a + b for a in b32tab for b in b32tab]
|
||||
if alphabet not in _b32tab2:
|
||||
b32tab = [bytes((i,)) for i in alphabet]
|
||||
_b32tab2[alphabet] = [a + b for a in b32tab for b in b32tab]
|
||||
b32tab = None
|
||||
|
||||
if not isinstance(s, bytes_types):
|
||||
@@ -158,9 +180,9 @@ def b32encode(s):
|
||||
s = s + b'\0' * (5 - leftover) # Don't use += !
|
||||
encoded = bytearray()
|
||||
from_bytes = int.from_bytes
|
||||
b32tab2 = _b32tab2
|
||||
b32tab2 = _b32tab2[alphabet]
|
||||
for i in range(0, len(s), 5):
|
||||
c = from_bytes(s[i: i + 5], 'big')
|
||||
c = from_bytes(s[i: i + 5]) # big endian
|
||||
encoded += (b32tab2[c >> 30] + # bits 1 - 10
|
||||
b32tab2[(c >> 20) & 0x3ff] + # bits 11 - 20
|
||||
b32tab2[(c >> 10) & 0x3ff] + # bits 21 - 30
|
||||
@@ -177,29 +199,12 @@ def b32encode(s):
|
||||
encoded[-1:] = b'='
|
||||
return bytes(encoded)
|
||||
|
||||
def b32decode(s, casefold=False, map01=None):
|
||||
"""Decode the Base32 encoded bytes-like object or ASCII string s.
|
||||
|
||||
Optional casefold is a flag specifying whether a lowercase alphabet is
|
||||
acceptable as input. For security purposes, the default is False.
|
||||
|
||||
RFC 3548 allows for optional mapping of the digit 0 (zero) to the
|
||||
letter O (oh), and for optional mapping of the digit 1 (one) to
|
||||
either the letter I (eye) or letter L (el). The optional argument
|
||||
map01 when not None, specifies which letter the digit 1 should be
|
||||
mapped to (when map01 is not None, the digit 0 is always mapped to
|
||||
the letter O). For security purposes the default is None, so that
|
||||
0 and 1 are not allowed in the input.
|
||||
|
||||
The result is returned as a bytes object. A binascii.Error is raised if
|
||||
the input is incorrectly padded or if there are non-alphabet
|
||||
characters present in the input.
|
||||
"""
|
||||
def _b32decode(alphabet, s, casefold=False, map01=None):
|
||||
global _b32rev
|
||||
# Delay the initialization of the table to not waste memory
|
||||
# if the function is never called
|
||||
if _b32rev is None:
|
||||
_b32rev = {v: k for k, v in enumerate(_b32alphabet)}
|
||||
if alphabet not in _b32rev:
|
||||
_b32rev[alphabet] = {v: k for k, v in enumerate(alphabet)}
|
||||
s = _bytes_from_decode_data(s)
|
||||
if len(s) % 8:
|
||||
raise binascii.Error('Incorrect padding')
|
||||
@@ -220,7 +225,7 @@ def b32decode(s, casefold=False, map01=None):
|
||||
padchars = l - len(s)
|
||||
# Now decode the full quanta
|
||||
decoded = bytearray()
|
||||
b32rev = _b32rev
|
||||
b32rev = _b32rev[alphabet]
|
||||
for i in range(0, len(s), 8):
|
||||
quanta = s[i: i + 8]
|
||||
acc = 0
|
||||
@@ -229,18 +234,38 @@ def b32decode(s, casefold=False, map01=None):
|
||||
acc = (acc << 5) + b32rev[c]
|
||||
except KeyError:
|
||||
raise binascii.Error('Non-base32 digit found') from None
|
||||
decoded += acc.to_bytes(5, 'big')
|
||||
decoded += acc.to_bytes(5) # big endian
|
||||
# Process the last, partial quanta
|
||||
if l % 8 or padchars not in {0, 1, 3, 4, 6}:
|
||||
raise binascii.Error('Incorrect padding')
|
||||
if padchars and decoded:
|
||||
acc <<= 5 * padchars
|
||||
last = acc.to_bytes(5, 'big')
|
||||
last = acc.to_bytes(5) # big endian
|
||||
leftover = (43 - 5 * padchars) // 8 # 1: 4, 3: 3, 4: 2, 6: 1
|
||||
decoded[-5:] = last[:leftover]
|
||||
return bytes(decoded)
|
||||
|
||||
|
||||
def b32encode(s):
|
||||
return _b32encode(_b32alphabet, s)
|
||||
b32encode.__doc__ = _B32_ENCODE_DOCSTRING.format(encoding='base32')
|
||||
|
||||
def b32decode(s, casefold=False, map01=None):
|
||||
return _b32decode(_b32alphabet, s, casefold, map01)
|
||||
b32decode.__doc__ = _B32_DECODE_DOCSTRING.format(encoding='base32',
|
||||
extra_args=_B32_DECODE_MAP01_DOCSTRING)
|
||||
|
||||
def b32hexencode(s):
|
||||
return _b32encode(_b32hexalphabet, s)
|
||||
b32hexencode.__doc__ = _B32_ENCODE_DOCSTRING.format(encoding='base32hex')
|
||||
|
||||
def b32hexdecode(s, casefold=False):
|
||||
# base32hex does not have the 01 mapping
|
||||
return _b32decode(_b32hexalphabet, s, casefold)
|
||||
b32hexdecode.__doc__ = _B32_DECODE_DOCSTRING.format(encoding='base32hex',
|
||||
extra_args='')
|
||||
|
||||
|
||||
# RFC 3548, Base 16 Alphabet specifies uppercase, but hexlify() returns
|
||||
# lowercase. The RFC also recommends against accepting input case
|
||||
# insensitively.
|
||||
@@ -320,7 +345,7 @@ def a85encode(b, *, foldspaces=False, wrapcol=0, pad=False, adobe=False):
|
||||
global _a85chars, _a85chars2
|
||||
# Delay the initialization of tables to not waste memory
|
||||
# if the function is never called
|
||||
if _a85chars is None:
|
||||
if _a85chars2 is None:
|
||||
_a85chars = [bytes((i,)) for i in range(33, 118)]
|
||||
_a85chars2 = [(a + b) for a in _a85chars for b in _a85chars]
|
||||
|
||||
@@ -428,7 +453,7 @@ def b85encode(b, pad=False):
|
||||
global _b85chars, _b85chars2
|
||||
# Delay the initialization of tables to not waste memory
|
||||
# if the function is never called
|
||||
if _b85chars is None:
|
||||
if _b85chars2 is None:
|
||||
_b85chars = [bytes((i,)) for i in _b85alphabet]
|
||||
_b85chars2 = [(a + b) for a in _b85chars for b in _b85chars]
|
||||
return _85encode(b, _b85chars, _b85chars2, pad)
|
||||
@@ -531,42 +556,28 @@ def encodebytes(s):
|
||||
pieces.append(binascii.b2a_base64(chunk))
|
||||
return b"".join(pieces)
|
||||
|
||||
def encodestring(s):
|
||||
"""Legacy alias of encodebytes()."""
|
||||
import warnings
|
||||
warnings.warn("encodestring() is a deprecated alias since 3.1, "
|
||||
"use encodebytes()",
|
||||
DeprecationWarning, 2)
|
||||
return encodebytes(s)
|
||||
|
||||
|
||||
def decodebytes(s):
|
||||
"""Decode a bytestring of base-64 data into a bytes object."""
|
||||
_input_type_check(s)
|
||||
return binascii.a2b_base64(s)
|
||||
|
||||
def decodestring(s):
|
||||
"""Legacy alias of decodebytes()."""
|
||||
import warnings
|
||||
warnings.warn("decodestring() is a deprecated alias since Python 3.1, "
|
||||
"use decodebytes()",
|
||||
DeprecationWarning, 2)
|
||||
return decodebytes(s)
|
||||
|
||||
|
||||
# Usable as a script...
|
||||
def main():
|
||||
"""Small main program"""
|
||||
import sys, getopt
|
||||
usage = """usage: %s [-h|-d|-e|-u|-t] [file|-]
|
||||
-h: print this help message and exit
|
||||
-d, -u: decode
|
||||
-e: encode (default)
|
||||
-t: encode and decode string 'Aladdin:open sesame'"""%sys.argv[0]
|
||||
try:
|
||||
opts, args = getopt.getopt(sys.argv[1:], 'deut')
|
||||
opts, args = getopt.getopt(sys.argv[1:], 'hdeut')
|
||||
except getopt.error as msg:
|
||||
sys.stdout = sys.stderr
|
||||
print(msg)
|
||||
print("""usage: %s [-d|-e|-u|-t] [file|-]
|
||||
-d, -u: decode
|
||||
-e: encode (default)
|
||||
-t: encode and decode string 'Aladdin:open sesame'"""%sys.argv[0])
|
||||
print(usage)
|
||||
sys.exit(2)
|
||||
func = encode
|
||||
for o, a in opts:
|
||||
@@ -574,6 +585,7 @@ def main():
|
||||
if o == '-d': func = decode
|
||||
if o == '-u': func = decode
|
||||
if o == '-t': test(); return
|
||||
if o == '-h': print(usage); return
|
||||
if args and args[0] != '-':
|
||||
with open(args[0], 'rb') as f:
|
||||
func(f, sys.stdout.buffer)
|
||||
|
||||
134
Lib/test/test_base64.py
vendored
134
Lib/test/test_base64.py
vendored
@@ -1,10 +1,10 @@
|
||||
import unittest
|
||||
from test import support
|
||||
import base64
|
||||
import binascii
|
||||
import os
|
||||
from array import array
|
||||
from test.support import script_helper, os_helper
|
||||
from test.support import os_helper
|
||||
from test.support import script_helper
|
||||
|
||||
|
||||
class LegacyBase64TestCase(unittest.TestCase):
|
||||
@@ -18,14 +18,6 @@ class LegacyBase64TestCase(unittest.TestCase):
|
||||
int_data = memoryview(b"1234").cast('I')
|
||||
self.assertRaises(TypeError, f, int_data)
|
||||
|
||||
def test_encodestring_warns(self):
|
||||
with self.assertWarns(DeprecationWarning):
|
||||
base64.encodestring(b"www.python.org")
|
||||
|
||||
def test_decodestring_warns(self):
|
||||
with self.assertWarns(DeprecationWarning):
|
||||
base64.decodestring(b"d3d3LnB5dGhvbi5vcmc=\n")
|
||||
|
||||
def test_encodebytes(self):
|
||||
eq = self.assertEqual
|
||||
eq(base64.encodebytes(b"www.python.org"), b"d3d3LnB5dGhvbi5vcmc=\n")
|
||||
@@ -129,6 +121,7 @@ class BaseXYTestCase(unittest.TestCase):
|
||||
int_data = memoryview(bytes_data).cast('I')
|
||||
self.assertEqual(f(int_data), f(bytes_data))
|
||||
|
||||
|
||||
def test_b64encode(self):
|
||||
eq = self.assertEqual
|
||||
# Test default alphabet
|
||||
@@ -239,8 +232,6 @@ class BaseXYTestCase(unittest.TestCase):
|
||||
self.assertRaises(binascii.Error, base64.b64decode, b'abc')
|
||||
self.assertRaises(binascii.Error, base64.b64decode, 'abc')
|
||||
|
||||
# TODO: RUSTPYTHON
|
||||
@unittest.expectedFailure
|
||||
def test_b64decode_invalid_chars(self):
|
||||
# issue 1466065: Test some invalid characters.
|
||||
tests = ((b'%3d==', b'\xdd'),
|
||||
@@ -360,6 +351,76 @@ class BaseXYTestCase(unittest.TestCase):
|
||||
with self.assertRaises(binascii.Error):
|
||||
base64.b32decode(data.decode('ascii'))
|
||||
|
||||
def test_b32hexencode(self):
|
||||
test_cases = [
|
||||
# to_encode, expected
|
||||
(b'', b''),
|
||||
(b'\x00', b'00======'),
|
||||
(b'a', b'C4======'),
|
||||
(b'ab', b'C5H0===='),
|
||||
(b'abc', b'C5H66==='),
|
||||
(b'abcd', b'C5H66P0='),
|
||||
(b'abcde', b'C5H66P35'),
|
||||
]
|
||||
for to_encode, expected in test_cases:
|
||||
with self.subTest(to_decode=to_encode):
|
||||
self.assertEqual(base64.b32hexencode(to_encode), expected)
|
||||
|
||||
def test_b32hexencode_other_types(self):
|
||||
self.check_other_types(base64.b32hexencode, b'abcd', b'C5H66P0=')
|
||||
self.check_encode_type_errors(base64.b32hexencode)
|
||||
|
||||
def test_b32hexdecode(self):
|
||||
test_cases = [
|
||||
# to_decode, expected, casefold
|
||||
(b'', b'', False),
|
||||
(b'00======', b'\x00', False),
|
||||
(b'C4======', b'a', False),
|
||||
(b'C5H0====', b'ab', False),
|
||||
(b'C5H66===', b'abc', False),
|
||||
(b'C5H66P0=', b'abcd', False),
|
||||
(b'C5H66P35', b'abcde', False),
|
||||
(b'', b'', True),
|
||||
(b'00======', b'\x00', True),
|
||||
(b'C4======', b'a', True),
|
||||
(b'C5H0====', b'ab', True),
|
||||
(b'C5H66===', b'abc', True),
|
||||
(b'C5H66P0=', b'abcd', True),
|
||||
(b'C5H66P35', b'abcde', True),
|
||||
(b'c4======', b'a', True),
|
||||
(b'c5h0====', b'ab', True),
|
||||
(b'c5h66===', b'abc', True),
|
||||
(b'c5h66p0=', b'abcd', True),
|
||||
(b'c5h66p35', b'abcde', True),
|
||||
]
|
||||
for to_decode, expected, casefold in test_cases:
|
||||
with self.subTest(to_decode=to_decode, casefold=casefold):
|
||||
self.assertEqual(base64.b32hexdecode(to_decode, casefold),
|
||||
expected)
|
||||
self.assertEqual(base64.b32hexdecode(to_decode.decode('ascii'),
|
||||
casefold), expected)
|
||||
|
||||
def test_b32hexdecode_other_types(self):
|
||||
self.check_other_types(base64.b32hexdecode, b'C5H66===', b'abc')
|
||||
self.check_decode_type_errors(base64.b32hexdecode)
|
||||
|
||||
def test_b32hexdecode_error(self):
|
||||
tests = [b'abc', b'ABCDEF==', b'==ABCDEF', b'c4======']
|
||||
prefixes = [b'M', b'ME', b'MFRA', b'MFRGG', b'MFRGGZA', b'MFRGGZDF']
|
||||
for i in range(0, 17):
|
||||
if i:
|
||||
tests.append(b'='*i)
|
||||
for prefix in prefixes:
|
||||
if len(prefix) + i != 8:
|
||||
tests.append(prefix + b'='*i)
|
||||
for data in tests:
|
||||
with self.subTest(to_decode=data):
|
||||
with self.assertRaises(binascii.Error):
|
||||
base64.b32hexdecode(data)
|
||||
with self.assertRaises(binascii.Error):
|
||||
base64.b32hexdecode(data.decode('ascii'))
|
||||
|
||||
|
||||
def test_b16encode(self):
|
||||
eq = self.assertEqual
|
||||
eq(base64.b16encode(b'\x01\x02\xab\xcd\xef'), b'0102ABCDEF')
|
||||
@@ -653,6 +714,45 @@ class BaseXYTestCase(unittest.TestCase):
|
||||
def test_ErrorHeritage(self):
|
||||
self.assertTrue(issubclass(binascii.Error, ValueError))
|
||||
|
||||
def test_RFC4648_test_cases(self):
|
||||
# test cases from RFC 4648 section 10
|
||||
b64encode = base64.b64encode
|
||||
b32hexencode = base64.b32hexencode
|
||||
b32encode = base64.b32encode
|
||||
b16encode = base64.b16encode
|
||||
|
||||
self.assertEqual(b64encode(b""), b"")
|
||||
self.assertEqual(b64encode(b"f"), b"Zg==")
|
||||
self.assertEqual(b64encode(b"fo"), b"Zm8=")
|
||||
self.assertEqual(b64encode(b"foo"), b"Zm9v")
|
||||
self.assertEqual(b64encode(b"foob"), b"Zm9vYg==")
|
||||
self.assertEqual(b64encode(b"fooba"), b"Zm9vYmE=")
|
||||
self.assertEqual(b64encode(b"foobar"), b"Zm9vYmFy")
|
||||
|
||||
self.assertEqual(b32encode(b""), b"")
|
||||
self.assertEqual(b32encode(b"f"), b"MY======")
|
||||
self.assertEqual(b32encode(b"fo"), b"MZXQ====")
|
||||
self.assertEqual(b32encode(b"foo"), b"MZXW6===")
|
||||
self.assertEqual(b32encode(b"foob"), b"MZXW6YQ=")
|
||||
self.assertEqual(b32encode(b"fooba"), b"MZXW6YTB")
|
||||
self.assertEqual(b32encode(b"foobar"), b"MZXW6YTBOI======")
|
||||
|
||||
self.assertEqual(b32hexencode(b""), b"")
|
||||
self.assertEqual(b32hexencode(b"f"), b"CO======")
|
||||
self.assertEqual(b32hexencode(b"fo"), b"CPNG====")
|
||||
self.assertEqual(b32hexencode(b"foo"), b"CPNMU===")
|
||||
self.assertEqual(b32hexencode(b"foob"), b"CPNMUOG=")
|
||||
self.assertEqual(b32hexencode(b"fooba"), b"CPNMUOJ1")
|
||||
self.assertEqual(b32hexencode(b"foobar"), b"CPNMUOJ1E8======")
|
||||
|
||||
self.assertEqual(b16encode(b""), b"")
|
||||
self.assertEqual(b16encode(b"f"), b"66")
|
||||
self.assertEqual(b16encode(b"fo"), b"666F")
|
||||
self.assertEqual(b16encode(b"foo"), b"666F6F")
|
||||
self.assertEqual(b16encode(b"foob"), b"666F6F62")
|
||||
self.assertEqual(b16encode(b"fooba"), b"666F6F6261")
|
||||
self.assertEqual(b16encode(b"foobar"), b"666F6F626172")
|
||||
|
||||
|
||||
class TestMain(unittest.TestCase):
|
||||
def tearDown(self):
|
||||
@@ -688,5 +788,15 @@ class TestMain(unittest.TestCase):
|
||||
output = self.get_output('-d', os_helper.TESTFN)
|
||||
self.assertEqual(output.rstrip(), b'a\xffb')
|
||||
|
||||
def test_prints_usage_with_help_flag(self):
|
||||
output = self.get_output('-h')
|
||||
self.assertIn(b'usage: ', output)
|
||||
self.assertIn(b'-d, -u: decode', output)
|
||||
|
||||
def test_prints_usage_with_invalid_flag(self):
|
||||
output = script_helper.assert_python_failure('-m', 'base64', '-x').err
|
||||
self.assertIn(b'usage: ', output)
|
||||
self.assertIn(b'-d, -u: decode', output)
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
||||
@@ -614,7 +614,7 @@ impl PyInt {
|
||||
#[pymethod]
|
||||
fn to_bytes(&self, args: IntToByteArgs, vm: &VirtualMachine) -> PyResult<PyBytes> {
|
||||
let signed = args.signed.map_or(false, Into::into);
|
||||
let byte_len = args.length.try_to_primitive(vm)?;
|
||||
let byte_len = args.length;
|
||||
|
||||
let value = self.as_bigint();
|
||||
match value.sign() {
|
||||
@@ -802,7 +802,9 @@ struct IntFromByteArgs {
|
||||
|
||||
#[derive(FromArgs)]
|
||||
struct IntToByteArgs {
|
||||
length: PyIntRef,
|
||||
#[pyarg(any, default = "1")]
|
||||
length: usize,
|
||||
#[pyarg(any, default = "ArgByteOrder::Big")]
|
||||
byteorder: ArgByteOrder,
|
||||
#[pyarg(named, optional)]
|
||||
signed: OptionalArg<ArgIntoBool>,
|
||||
|
||||
Reference in New Issue
Block a user