Copy secrets module from CPython 3.8

This commit is contained in:
boris
2021-01-31 18:55:48 -05:00
parent a39a19fbe4
commit b03185f59e
2 changed files with 197 additions and 0 deletions

73
Lib/secrets.py vendored Normal file
View File

@@ -0,0 +1,73 @@
"""Generate cryptographically strong pseudo-random numbers suitable for
managing secrets such as account authentication, tokens, and similar.
See PEP 506 for more information.
https://www.python.org/dev/peps/pep-0506/
"""
__all__ = ['choice', 'randbelow', 'randbits', 'SystemRandom',
'token_bytes', 'token_hex', 'token_urlsafe',
'compare_digest',
]
import base64
import binascii
import os
from hmac import compare_digest
from random import SystemRandom
_sysrand = SystemRandom()
randbits = _sysrand.getrandbits
choice = _sysrand.choice
def randbelow(exclusive_upper_bound):
"""Return a random int in the range [0, n)."""
if exclusive_upper_bound <= 0:
raise ValueError("Upper bound must be positive.")
return _sysrand._randbelow(exclusive_upper_bound)
DEFAULT_ENTROPY = 32 # number of bytes to return by default
def token_bytes(nbytes=None):
"""Return a random byte string containing *nbytes* bytes.
If *nbytes* is ``None`` or not supplied, a reasonable
default is used.
>>> token_bytes(16) #doctest:+SKIP
b'\\xebr\\x17D*t\\xae\\xd4\\xe3S\\xb6\\xe2\\xebP1\\x8b'
"""
if nbytes is None:
nbytes = DEFAULT_ENTROPY
return os.urandom(nbytes)
def token_hex(nbytes=None):
"""Return a random text string, in hexadecimal.
The string has *nbytes* random bytes, each byte converted to two
hex digits. If *nbytes* is ``None`` or not supplied, a reasonable
default is used.
>>> token_hex(16) #doctest:+SKIP
'f9bf78b9a18ce6d46a0cd2b0b86df9da'
"""
return binascii.hexlify(token_bytes(nbytes)).decode('ascii')
def token_urlsafe(nbytes=None):
"""Return a random URL-safe text string, in Base64 encoding.
The string has *nbytes* random bytes. If *nbytes* is ``None``
or not supplied, a reasonable default is used.
>>> token_urlsafe(16) #doctest:+SKIP
'Drmhze6EPcv0fN_81Bj-nA'
"""
tok = token_bytes(nbytes)
return base64.urlsafe_b64encode(tok).rstrip(b'=').decode('ascii')

124
Lib/test/test_secrets.py Normal file
View File

@@ -0,0 +1,124 @@
"""Test the secrets module.
As most of the functions in secrets are thin wrappers around functions
defined elsewhere, we don't need to test them exhaustively.
"""
import secrets
import unittest
import string
# === Unit tests ===
class Compare_Digest_Tests(unittest.TestCase):
"""Test secrets.compare_digest function."""
def test_equal(self):
# Test compare_digest functionality with equal (byte/text) strings.
for s in ("a", "bcd", "xyz123"):
a = s*100
b = s*100
self.assertTrue(secrets.compare_digest(a, b))
self.assertTrue(secrets.compare_digest(a.encode('utf-8'), b.encode('utf-8')))
def test_unequal(self):
# Test compare_digest functionality with unequal (byte/text) strings.
self.assertFalse(secrets.compare_digest("abc", "abcd"))
self.assertFalse(secrets.compare_digest(b"abc", b"abcd"))
for s in ("x", "mn", "a1b2c3"):
a = s*100 + "q"
b = s*100 + "k"
self.assertFalse(secrets.compare_digest(a, b))
self.assertFalse(secrets.compare_digest(a.encode('utf-8'), b.encode('utf-8')))
def test_bad_types(self):
# Test that compare_digest raises with mixed types.
a = 'abcde'
b = a.encode('utf-8')
assert isinstance(a, str)
assert isinstance(b, bytes)
self.assertRaises(TypeError, secrets.compare_digest, a, b)
self.assertRaises(TypeError, secrets.compare_digest, b, a)
def test_bool(self):
# Test that compare_digest returns a bool.
self.assertIsInstance(secrets.compare_digest("abc", "abc"), bool)
self.assertIsInstance(secrets.compare_digest("abc", "xyz"), bool)
class Random_Tests(unittest.TestCase):
"""Test wrappers around SystemRandom methods."""
def test_randbits(self):
# Test randbits.
errmsg = "randbits(%d) returned %d"
for numbits in (3, 12, 30):
for i in range(6):
n = secrets.randbits(numbits)
self.assertTrue(0 <= n < 2**numbits, errmsg % (numbits, n))
def test_choice(self):
# Test choice.
items = [1, 2, 4, 8, 16, 32, 64]
for i in range(10):
self.assertTrue(secrets.choice(items) in items)
def test_randbelow(self):
# Test randbelow.
for i in range(2, 10):
self.assertIn(secrets.randbelow(i), range(i))
self.assertRaises(ValueError, secrets.randbelow, 0)
self.assertRaises(ValueError, secrets.randbelow, -1)
class Token_Tests(unittest.TestCase):
"""Test token functions."""
def test_token_defaults(self):
# Test that token_* functions handle default size correctly.
for func in (secrets.token_bytes, secrets.token_hex,
secrets.token_urlsafe):
with self.subTest(func=func):
name = func.__name__
try:
func()
except TypeError:
self.fail("%s cannot be called with no argument" % name)
try:
func(None)
except TypeError:
self.fail("%s cannot be called with None" % name)
size = secrets.DEFAULT_ENTROPY
self.assertEqual(len(secrets.token_bytes(None)), size)
self.assertEqual(len(secrets.token_hex(None)), 2*size)
def test_token_bytes(self):
# Test token_bytes.
for n in (1, 8, 17, 100):
with self.subTest(n=n):
self.assertIsInstance(secrets.token_bytes(n), bytes)
self.assertEqual(len(secrets.token_bytes(n)), n)
def test_token_hex(self):
# Test token_hex.
for n in (1, 12, 25, 90):
with self.subTest(n=n):
s = secrets.token_hex(n)
self.assertIsInstance(s, str)
self.assertEqual(len(s), 2*n)
self.assertTrue(all(c in string.hexdigits for c in s))
def test_token_urlsafe(self):
# Test token_urlsafe.
legal = string.ascii_letters + string.digits + '-_'
for n in (1, 11, 28, 76):
with self.subTest(n=n):
s = secrets.token_urlsafe(n)
self.assertIsInstance(s, str)
self.assertTrue(all(c in legal for c in s))
if __name__ == '__main__':
unittest.main()