From bed2ddee4687014c658012ca112ed7dc3eecad4e Mon Sep 17 00:00:00 2001 From: Noah <33094578+coolreader18@users.noreply.github.com> Date: Wed, 27 Jan 2021 21:01:41 -0600 Subject: [PATCH] Update ssl.py to version from CPython 3.8 --- Lib/ssl.py | 789 +++++++++++++++++++++++++++++++++++------------------ 1 file changed, 519 insertions(+), 270 deletions(-) diff --git a/Lib/ssl.py b/Lib/ssl.py index b9fa7933c..1200d7d99 100644 --- a/Lib/ssl.py +++ b/Lib/ssl.py @@ -90,9 +90,6 @@ ALERT_DESCRIPTION_BAD_CERTIFICATE_HASH_VALUE ALERT_DESCRIPTION_UNKNOWN_PSK_IDENTITY """ -import ipaddress -import textwrap -import re import sys import os from collections import namedtuple @@ -100,12 +97,11 @@ from enum import Enum as _Enum, IntEnum as _IntEnum, IntFlag as _IntFlag import _ssl # if we can't import it, let the error propagate -# XXX RustPython TODO: provide more of these imports from _ssl import OPENSSL_VERSION_NUMBER, OPENSSL_VERSION_INFO, OPENSSL_VERSION -from _ssl import _SSLContext #, MemoryBIO, SSLSession +from _ssl import _SSLContext#, MemoryBIO, SSLSession from _ssl import ( - SSLError, #SSLZeroReturnError, SSLWantReadError, SSLWantWriteError, -# SSLSyscallError, SSLEOFError, + SSLError, SSLZeroReturnError, SSLWantReadError, SSLWantWriteError, + SSLSyscallError, SSLEOFError, SSLCertVerificationError ) from _ssl import txt2obj as _txt2obj, nid2obj as _nid2obj from _ssl import RAND_status, RAND_add, RAND_bytes, RAND_pseudo_bytes @@ -116,8 +112,11 @@ except ImportError: pass -# from _ssl import HAS_SNI, HAS_ECDH, HAS_NPN, HAS_ALPN, HAS_TLSv1_3 -# from _ssl import _OPENSSL_API_VERSION +from _ssl import ( + HAS_SNI, HAS_ECDH, HAS_NPN, HAS_ALPN, HAS_SSLv2, HAS_SSLv3, HAS_TLSv1, + HAS_TLSv1_1, HAS_TLSv1_2, HAS_TLSv1_3 +) +from _ssl import _DEFAULT_CIPHERS, _OPENSSL_API_VERSION _IntEnum._convert_( @@ -150,18 +149,112 @@ _IntEnum._convert_( lambda name: name.startswith('CERT_'), source=_ssl) - PROTOCOL_SSLv23 = _SSLMethod.PROTOCOL_SSLv23 = _SSLMethod.PROTOCOL_TLS _PROTOCOL_NAMES = {value: name for name, value in _SSLMethod.__members__.items()} _SSLv2_IF_EXISTS = getattr(_SSLMethod, 'PROTOCOL_SSLv2', None) +class TLSVersion(_IntEnum): + MINIMUM_SUPPORTED = _ssl.PROTO_MINIMUM_SUPPORTED + SSLv3 = _ssl.PROTO_SSLv3 + TLSv1 = _ssl.PROTO_TLSv1 + TLSv1_1 = _ssl.PROTO_TLSv1_1 + TLSv1_2 = _ssl.PROTO_TLSv1_2 + TLSv1_3 = _ssl.PROTO_TLSv1_3 + MAXIMUM_SUPPORTED = _ssl.PROTO_MAXIMUM_SUPPORTED + + +class _TLSContentType(_IntEnum): + """Content types (record layer) + + See RFC 8446, section B.1 + """ + CHANGE_CIPHER_SPEC = 20 + ALERT = 21 + HANDSHAKE = 22 + APPLICATION_DATA = 23 + # pseudo content types + HEADER = 0x100 + INNER_CONTENT_TYPE = 0x101 + + +class _TLSAlertType(_IntEnum): + """Alert types for TLSContentType.ALERT messages + + See RFC 8466, section B.2 + """ + CLOSE_NOTIFY = 0 + UNEXPECTED_MESSAGE = 10 + BAD_RECORD_MAC = 20 + DECRYPTION_FAILED = 21 + RECORD_OVERFLOW = 22 + DECOMPRESSION_FAILURE = 30 + HANDSHAKE_FAILURE = 40 + NO_CERTIFICATE = 41 + BAD_CERTIFICATE = 42 + UNSUPPORTED_CERTIFICATE = 43 + CERTIFICATE_REVOKED = 44 + CERTIFICATE_EXPIRED = 45 + CERTIFICATE_UNKNOWN = 46 + ILLEGAL_PARAMETER = 47 + UNKNOWN_CA = 48 + ACCESS_DENIED = 49 + DECODE_ERROR = 50 + DECRYPT_ERROR = 51 + EXPORT_RESTRICTION = 60 + PROTOCOL_VERSION = 70 + INSUFFICIENT_SECURITY = 71 + INTERNAL_ERROR = 80 + INAPPROPRIATE_FALLBACK = 86 + USER_CANCELED = 90 + NO_RENEGOTIATION = 100 + MISSING_EXTENSION = 109 + UNSUPPORTED_EXTENSION = 110 + CERTIFICATE_UNOBTAINABLE = 111 + UNRECOGNIZED_NAME = 112 + BAD_CERTIFICATE_STATUS_RESPONSE = 113 + BAD_CERTIFICATE_HASH_VALUE = 114 + UNKNOWN_PSK_IDENTITY = 115 + CERTIFICATE_REQUIRED = 116 + NO_APPLICATION_PROTOCOL = 120 + + +class _TLSMessageType(_IntEnum): + """Message types (handshake protocol) + + See RFC 8446, section B.3 + """ + HELLO_REQUEST = 0 + CLIENT_HELLO = 1 + SERVER_HELLO = 2 + HELLO_VERIFY_REQUEST = 3 + NEWSESSION_TICKET = 4 + END_OF_EARLY_DATA = 5 + HELLO_RETRY_REQUEST = 6 + ENCRYPTED_EXTENSIONS = 8 + CERTIFICATE = 11 + SERVER_KEY_EXCHANGE = 12 + CERTIFICATE_REQUEST = 13 + SERVER_DONE = 14 + CERTIFICATE_VERIFY = 15 + CLIENT_KEY_EXCHANGE = 16 + FINISHED = 20 + CERTIFICATE_URL = 21 + CERTIFICATE_STATUS = 22 + SUPPLEMENTAL_DATA = 23 + KEY_UPDATE = 24 + NEXT_PROTO = 67 + MESSAGE_HASH = 254 + CHANGE_CIPHER_SPEC = 0x0101 + + if sys.platform == "win32": from _ssl import enum_certificates #, enum_crls from socket import socket, AF_INET, SOCK_STREAM, create_connection from socket import SOL_SOCKET, SO_TYPE +import socket as _socket import base64 # for DER-to-PEM translation import errno import warnings @@ -169,124 +262,121 @@ import warnings socket_error = OSError # keep that public name in module namespace -if _ssl.HAS_TLS_UNIQUE: - CHANNEL_BINDING_TYPES = ['tls-unique'] -else: - CHANNEL_BINDING_TYPES = [] +CHANNEL_BINDING_TYPES = ['tls-unique'] + +HAS_NEVER_CHECK_COMMON_NAME = hasattr(_ssl, 'HOSTFLAG_NEVER_CHECK_SUBJECT') -# Disable weak or insecure ciphers by default -# (OpenSSL's default setting is 'DEFAULT:!aNULL:!eNULL') -# Enable a better set of ciphers by default -# This list has been explicitly chosen to: -# * TLS 1.3 ChaCha20 and AES-GCM cipher suites -# * Prefer cipher suites that offer perfect forward secrecy (DHE/ECDHE) -# * Prefer ECDHE over DHE for better performance -# * Prefer AEAD over CBC for better performance and security -# * Prefer AES-GCM over ChaCha20 because most platforms have AES-NI -# (ChaCha20 needs OpenSSL 1.1.0 or patched 1.0.2) -# * Prefer any AES-GCM and ChaCha20 over any AES-CBC for better -# performance and security -# * Then Use HIGH cipher suites as a fallback -# * Disable NULL authentication, NULL encryption, 3DES and MD5 MACs -# for security reasons -_DEFAULT_CIPHERS = ( - 'TLS13-AES-256-GCM-SHA384:TLS13-CHACHA20-POLY1305-SHA256:' - 'TLS13-AES-128-GCM-SHA256:' - 'ECDH+AESGCM:ECDH+CHACHA20:DH+AESGCM:DH+CHACHA20:ECDH+AES256:DH+AES256:' - 'ECDH+AES128:DH+AES:ECDH+HIGH:DH+HIGH:RSA+AESGCM:RSA+AES:RSA+HIGH:' - '!aNULL:!eNULL:!MD5:!3DES' - ) +_RESTRICTED_SERVER_CIPHERS = _DEFAULT_CIPHERS -# Restricted and more secure ciphers for the server side -# This list has been explicitly chosen to: -# * TLS 1.3 ChaCha20 and AES-GCM cipher suites -# * Prefer cipher suites that offer perfect forward secrecy (DHE/ECDHE) -# * Prefer ECDHE over DHE for better performance -# * Prefer AEAD over CBC for better performance and security -# * Prefer AES-GCM over ChaCha20 because most platforms have AES-NI -# * Prefer any AES-GCM and ChaCha20 over any AES-CBC for better -# performance and security -# * Then Use HIGH cipher suites as a fallback -# * Disable NULL authentication, NULL encryption, MD5 MACs, DSS, RC4, and -# 3DES for security reasons -_RESTRICTED_SERVER_CIPHERS = ( - 'TLS13-AES-256-GCM-SHA384:TLS13-CHACHA20-POLY1305-SHA256:' - 'TLS13-AES-128-GCM-SHA256:' - 'ECDH+AESGCM:ECDH+CHACHA20:DH+AESGCM:DH+CHACHA20:ECDH+AES256:DH+AES256:' - 'ECDH+AES128:DH+AES:ECDH+HIGH:DH+HIGH:RSA+AESGCM:RSA+AES:RSA+HIGH:' - '!aNULL:!eNULL:!MD5:!DSS:!RC4:!3DES' -) +CertificateError = SSLCertVerificationError -class CertificateError(ValueError): - pass - - -def _dnsname_match(dn, hostname, max_wildcards=1): +def _dnsname_match(dn, hostname): """Matching according to RFC 6125, section 6.4.3 - http://tools.ietf.org/html/rfc6125#section-6.4.3 + - Hostnames are compared lower case. + - For IDNA, both dn and hostname must be encoded as IDN A-label (ACE). + - Partial wildcards like 'www*.example.org', multiple wildcards, sole + wildcard or wildcards in labels other then the left-most label are not + supported and a CertificateError is raised. + - A wildcard must match at least one character. """ - pats = [] if not dn: return False - leftmost, *remainder = dn.split(r'.') - - wildcards = leftmost.count('*') - if wildcards > max_wildcards: - # Issue #17980: avoid denials of service by refusing more - # than one wildcard per fragment. A survey of established - # policy among SSL implementations showed it to be a - # reasonable choice. - raise CertificateError( - "too many wildcards in certificate DNS name: " + repr(dn)) - + wildcards = dn.count('*') # speed up common case w/o wildcards if not wildcards: return dn.lower() == hostname.lower() - # RFC 6125, section 6.4.3, subitem 1. - # The client SHOULD NOT attempt to match a presented identifier in which - # the wildcard character comprises a label other than the left-most label. - if leftmost == '*': - # When '*' is a fragment by itself, it matches a non-empty dotless - # fragment. - pats.append('[^.]+') - elif leftmost.startswith('xn--') or hostname.startswith('xn--'): - # RFC 6125, section 6.4.3, subitem 3. - # The client SHOULD NOT attempt to match a presented identifier - # where the wildcard character is embedded within an A-label or - # U-label of an internationalized domain name. - pats.append(re.escape(leftmost)) + if wildcards > 1: + raise CertificateError( + "too many wildcards in certificate DNS name: {!r}.".format(dn)) + + dn_leftmost, sep, dn_remainder = dn.partition('.') + + if '*' in dn_remainder: + # Only match wildcard in leftmost segment. + raise CertificateError( + "wildcard can only be present in the leftmost label: " + "{!r}.".format(dn)) + + if not sep: + # no right side + raise CertificateError( + "sole wildcard without additional labels are not support: " + "{!r}.".format(dn)) + + if dn_leftmost != '*': + # no partial wildcard matching + raise CertificateError( + "partial wildcards in leftmost label are not supported: " + "{!r}.".format(dn)) + + hostname_leftmost, sep, hostname_remainder = hostname.partition('.') + if not hostname_leftmost or not sep: + # wildcard must match at least one char + return False + return dn_remainder.lower() == hostname_remainder.lower() + + +def _inet_paton(ipname): + """Try to convert an IP address to packed binary form + + Supports IPv4 addresses on all platforms and IPv6 on platforms with IPv6 + support. + """ + # inet_aton() also accepts strings like '1', '127.1', some also trailing + # data like '127.0.0.1 whatever'. + try: + addr = _socket.inet_aton(ipname) + except OSError: + # not an IPv4 address + pass else: - # Otherwise, '*' matches any dotless string, e.g. www* - pats.append(re.escape(leftmost).replace(r'\*', '[^.]*')) + if _socket.inet_ntoa(addr) == ipname: + # only accept injective ipnames + return addr + else: + # refuse for short IPv4 notation and additional trailing data + raise ValueError( + "{!r} is not a quad-dotted IPv4 address.".format(ipname) + ) - # add the remaining fragments, ignore any wildcards - for frag in remainder: - pats.append(re.escape(frag)) + try: + return _socket.inet_pton(_socket.AF_INET6, ipname) + except OSError: + raise ValueError("{!r} is neither an IPv4 nor an IP6 " + "address.".format(ipname)) + except AttributeError: + # AF_INET6 not available + pass - pat = re.compile(r'\A' + r'\.'.join(pats) + r'\Z', re.IGNORECASE) - return pat.match(hostname) + raise ValueError("{!r} is not an IPv4 address.".format(ipname)) -def _ipaddress_match(ipname, host_ip): +def _ipaddress_match(cert_ipaddress, host_ip): """Exact matching of IP addresses. RFC 6125 explicitly doesn't define an algorithm for this (section 1.7.2 - "Out of Scope"). """ - # OpenSSL may add a trailing newline to a subjectAltName's IP address - ip = ipaddress.ip_address(ipname.rstrip()) + # OpenSSL may add a trailing newline to a subjectAltName's IP address, + # commonly woth IPv6 addresses. Strip off trailing \n. + ip = _inet_paton(cert_ipaddress.rstrip()) return ip == host_ip def match_hostname(cert, hostname): """Verify that *cert* (in decoded format as returned by SSLSocket.getpeercert()) matches the *hostname*. RFC 2818 and RFC 6125 - rules are followed, but IP addresses are not accepted for *hostname*. + rules are followed. + + The function matches IP addresses rather than dNSNames if hostname is a + valid ipaddress string. IPv4 addresses are supported on all platforms. + IPv6 addresses are supported on platforms with IPv6 support (AF_INET6 + and inet_pton). CertificateError is raised on failure. On success, the function returns nothing. @@ -296,7 +386,7 @@ def match_hostname(cert, hostname): "SSL socket or SSL context with either " "CERT_OPTIONAL or CERT_REQUIRED") try: - host_ip = ipaddress.ip_address(hostname) + host_ip = _inet_paton(hostname) except ValueError: # Not an IP address (common case) host_ip = None @@ -384,34 +474,48 @@ class Purpose(_ASN1Object, _Enum): class SSLContext(_SSLContext): """An SSLContext holds various SSL-related configuration options and data, such as certificates and possibly a private key.""" - - __slots__ = ('protocol', '__weakref__') _windows_cert_stores = ("CA", "ROOT") + sslsocket_class = None # SSLSocket is assigned later. + sslobject_class = None # SSLObject is assigned later. + def __new__(cls, protocol=PROTOCOL_TLS, *args, **kwargs): self = _SSLContext.__new__(cls, protocol) - if protocol != _SSLv2_IF_EXISTS: - self.set_ciphers(_DEFAULT_CIPHERS) return self - def __init__(self, protocol=PROTOCOL_TLS): - self.protocol = protocol + def _encode_hostname(self, hostname): + if hostname is None: + return None + elif isinstance(hostname, str): + return hostname.encode('idna').decode('ascii') + else: + return hostname.decode('ascii') def wrap_socket(self, sock, server_side=False, do_handshake_on_connect=True, suppress_ragged_eofs=True, server_hostname=None, session=None): - return SSLSocket(sock=sock, server_side=server_side, - do_handshake_on_connect=do_handshake_on_connect, - suppress_ragged_eofs=suppress_ragged_eofs, - server_hostname=server_hostname, - _context=self, _session=session) + # SSLSocket class handles server_hostname encoding before it calls + # ctx._wrap_socket() + return self.sslsocket_class._create( + sock=sock, + server_side=server_side, + do_handshake_on_connect=do_handshake_on_connect, + suppress_ragged_eofs=suppress_ragged_eofs, + server_hostname=server_hostname, + context=self, + session=session + ) def wrap_bio(self, incoming, outgoing, server_side=False, server_hostname=None, session=None): - sslobj = self._wrap_bio(incoming, outgoing, server_side=server_side, - server_hostname=server_hostname) - return SSLObject(sslobj, session=session) + # Need to encode server_hostname here because _wrap_bio() can only + # handle ASCII str. + return self.sslobject_class._create( + incoming, outgoing, server_side=server_side, + server_hostname=self._encode_hostname(server_hostname), + session=session, context=self, + ) def set_npn_protocols(self, npn_protocols): protos = bytearray() @@ -424,6 +528,19 @@ class SSLContext(_SSLContext): self._set_npn_protocols(protos) + def set_servername_callback(self, server_name_callback): + if server_name_callback is None: + self.sni_callback = None + else: + if not callable(server_name_callback): + raise TypeError("not a callable object") + + def shim_cb(sslobj, servername, sslctx): + servername = self._encode_hostname(servername) + return server_name_callback(sslobj, servername, sslctx) + + self.sni_callback = shim_cb + def set_alpn_protocols(self, alpn_protocols): protos = bytearray() for protocol in alpn_protocols: @@ -457,6 +574,25 @@ class SSLContext(_SSLContext): self._load_windows_store_certs(storename, purpose) self.set_default_verify_paths() + if hasattr(_SSLContext, 'minimum_version'): + @property + def minimum_version(self): + return TLSVersion(super().minimum_version) + + @minimum_version.setter + def minimum_version(self, value): + if value == TLSVersion.SSLv3: + self.options &= ~Options.OP_NO_SSLv3 + super(SSLContext, SSLContext).minimum_version.__set__(self, value) + + @property + def maximum_version(self): + return TLSVersion(super().maximum_version) + + @maximum_version.setter + def maximum_version(self, value): + super(SSLContext, SSLContext).maximum_version.__set__(self, value) + @property def options(self): return Options(super().options) @@ -465,6 +601,104 @@ class SSLContext(_SSLContext): def options(self, value): super(SSLContext, SSLContext).options.__set__(self, value) + if hasattr(_ssl, 'HOSTFLAG_NEVER_CHECK_SUBJECT'): + @property + def hostname_checks_common_name(self): + ncs = self._host_flags & _ssl.HOSTFLAG_NEVER_CHECK_SUBJECT + return ncs != _ssl.HOSTFLAG_NEVER_CHECK_SUBJECT + + @hostname_checks_common_name.setter + def hostname_checks_common_name(self, value): + if value: + self._host_flags &= ~_ssl.HOSTFLAG_NEVER_CHECK_SUBJECT + else: + self._host_flags |= _ssl.HOSTFLAG_NEVER_CHECK_SUBJECT + else: + @property + def hostname_checks_common_name(self): + return True + + @property + def _msg_callback(self): + """TLS message callback + + The message callback provides a debugging hook to analyze TLS + connections. The callback is called for any TLS protocol message + (header, handshake, alert, and more), but not for application data. + Due to technical limitations, the callback can't be used to filter + traffic or to abort a connection. Any exception raised in the + callback is delayed until the handshake, read, or write operation + has been performed. + + def msg_cb(conn, direction, version, content_type, msg_type, data): + pass + + conn + :class:`SSLSocket` or :class:`SSLObject` instance + direction + ``read`` or ``write`` + version + :class:`TLSVersion` enum member or int for unknown version. For a + frame header, it's the header version. + content_type + :class:`_TLSContentType` enum member or int for unsupported + content type. + msg_type + Either a :class:`_TLSContentType` enum number for a header + message, a :class:`_TLSAlertType` enum member for an alert + message, a :class:`_TLSMessageType` enum member for other + messages, or int for unsupported message types. + data + Raw, decrypted message content as bytes + """ + inner = super()._msg_callback + if inner is not None: + return inner.user_function + else: + return None + + @_msg_callback.setter + def _msg_callback(self, callback): + if callback is None: + super(SSLContext, SSLContext)._msg_callback.__set__(self, None) + return + + if not hasattr(callback, '__call__'): + raise TypeError(f"{callback} is not callable.") + + def inner(conn, direction, version, content_type, msg_type, data): + try: + version = TLSVersion(version) + except ValueError: + pass + + try: + content_type = _TLSContentType(content_type) + except ValueError: + pass + + if content_type == _TLSContentType.HEADER: + msg_enum = _TLSContentType + elif content_type == _TLSContentType.ALERT: + msg_enum = _TLSAlertType + else: + msg_enum = _TLSMessageType + try: + msg_type = msg_enum(msg_type) + except ValueError: + pass + + return callback(conn, direction, version, + content_type, msg_type, data) + + inner.user_function = callback + + super(SSLContext, SSLContext)._msg_callback.__set__(self, inner) + + @property + def protocol(self): + return _SSLMethod(super().protocol) + @property def verify_flags(self): return VerifyFlags(super().verify_flags) @@ -506,8 +740,6 @@ def create_default_context(purpose=Purpose.SERVER_AUTH, *, cafile=None, # verify certs and host name in client mode context.verify_mode = CERT_REQUIRED context.check_hostname = True - elif purpose == Purpose.CLIENT_AUTH: - context.set_ciphers(_RESTRICTED_SERVER_CIPHERS) if cafile or capath or cadata: context.load_verify_locations(cafile, capath, cadata) @@ -516,9 +748,14 @@ def create_default_context(purpose=Purpose.SERVER_AUTH, *, cafile=None, # CERT_OPTIONAL or CERT_REQUIRED. Let's try to load default system # root CA certificates for the given purpose. This may fail silently. context.load_default_certs(purpose) + # OpenSSL 1.1.1 keylog file + if hasattr(context, 'keylog_filename'): + keylogfile = os.environ.get('SSLKEYLOGFILE') + if keylogfile and not sys.flags.ignore_environment: + context.keylog_filename = keylogfile return context -def _create_unverified_context(protocol=PROTOCOL_TLS, *, cert_reqs=None, +def _create_unverified_context(protocol=PROTOCOL_TLS, *, cert_reqs=CERT_NONE, check_hostname=False, purpose=Purpose.SERVER_AUTH, certfile=None, keyfile=None, cafile=None, capath=None, cadata=None): @@ -537,9 +774,12 @@ def _create_unverified_context(protocol=PROTOCOL_TLS, *, cert_reqs=None, # by default. context = SSLContext(protocol) + if not check_hostname: + context.check_hostname = False if cert_reqs is not None: context.verify_mode = cert_reqs - context.check_hostname = check_hostname + if check_hostname: + context.check_hostname = True if keyfile and not certfile: raise ValueError("certfile must be specified") @@ -554,7 +794,11 @@ def _create_unverified_context(protocol=PROTOCOL_TLS, *, cert_reqs=None, # CERT_OPTIONAL or CERT_REQUIRED. Let's try to load default system # root CA certificates for the given purpose. This may fail silently. context.load_default_certs(purpose) - + # OpenSSL 1.1.1 keylog file + if hasattr(context, 'keylog_filename'): + keylogfile = os.environ.get('SSLKEYLOGFILE') + if keylogfile and not sys.flags.ignore_environment: + context.keylog_filename = keylogfile return context # Used by http.client if no context is explicitly passed. @@ -580,13 +824,23 @@ class SSLObject: * Any form of network IO, including methods such as ``recv`` and ``send``. * The ``do_handshake_on_connect`` and ``suppress_ragged_eofs`` machinery. """ + def __init__(self, *args, **kwargs): + raise TypeError( + f"{self.__class__.__name__} does not have a public " + f"constructor. Instances are returned by SSLContext.wrap_bio()." + ) - def __init__(self, sslobj, owner=None, session=None): + @classmethod + def _create(cls, incoming, outgoing, server_side=False, + server_hostname=None, session=None, context=None): + self = cls.__new__(cls) + sslobj = context._wrap_bio( + incoming, outgoing, server_side=server_side, + server_hostname=server_hostname, + owner=self, session=session + ) self._sslobj = sslobj - # Note: _sslobj takes a weak reference to owner - self._sslobj.owner = owner or self - if session is not None: - self._sslobj.session = session + return self @property def context(self): @@ -619,7 +873,7 @@ class SSLObject: @property def server_hostname(self): """The currently set server hostname (for SNI), or ``None`` if no - server hostame is set.""" + server hostname is set.""" return self._sslobj.server_hostname def read(self, len=1024, buffer=None): @@ -649,7 +903,7 @@ class SSLObject: Return None if no certificate was provided, {} if a certificate was provided, but not validated. """ - return self._sslobj.peer_certificate(binary_form) + return self._sslobj.getpeercert(binary_form) def selected_npn_protocol(self): """Return the currently selected NPN protocol as a string, or ``None`` @@ -688,11 +942,6 @@ class SSLObject: def do_handshake(self): """Start the SSL/TLS handshake.""" self._sslobj.do_handshake() - if self.context.check_hostname: - if not self.server_hostname: - raise ValueError("check_hostname needs server_hostname " - "argument") - match_hostname(self.getpeercert(), self.server_hostname) def unwrap(self): """Start the SSL shutdown handshake.""" @@ -702,13 +951,7 @@ class SSLObject: """Get channel binding data for current connection. Raise ValueError if the requested `cb_type` is not supported. Return bytes of the data or None if the data is not available (e.g. before the handshake).""" - if cb_type not in CHANNEL_BINDING_TYPES: - raise ValueError("Unsupported channel binding type") - if cb_type != "tls-unique": - raise NotImplementedError( - "{0} channel binding type not implemented" - .format(cb_type)) - return self._sslobj.tls_unique_cb() + return self._sslobj.get_channel_binding(cb_type) def version(self): """Return a string identifying the protocol version used by the @@ -719,76 +962,57 @@ class SSLObject: return self._sslobj.verify_client_post_handshake() +def _sslcopydoc(func): + """Copy docstring from SSLObject to SSLSocket""" + func.__doc__ = getattr(SSLObject, func.__name__).__doc__ + return func + + class SSLSocket(socket): """This class implements a subtype of socket.socket that wraps the underlying OS socket in an SSL context when necessary, and - provides read and write methods over that channel.""" + provides read and write methods over that channel. """ - def __init__(self, sock=None, keyfile=None, certfile=None, - server_side=False, cert_reqs=CERT_NONE, - ssl_version=PROTOCOL_TLS, ca_certs=None, - do_handshake_on_connect=True, - family=AF_INET, type=SOCK_STREAM, proto=0, fileno=None, - suppress_ragged_eofs=True, npn_protocols=None, ciphers=None, - server_hostname=None, - _context=None, _session=None): + def __init__(self, *args, **kwargs): + raise TypeError( + f"{self.__class__.__name__} does not have a public " + f"constructor. Instances are returned by " + f"SSLContext.wrap_socket()." + ) - if _context: - self._context = _context - else: - if server_side and not certfile: - raise ValueError("certfile must be specified for server-side " - "operations") - if keyfile and not certfile: - raise ValueError("certfile must be specified") - if certfile and not keyfile: - keyfile = certfile - self._context = SSLContext(ssl_version) - self._context.verify_mode = cert_reqs - if ca_certs: - self._context.load_verify_locations(ca_certs) - if certfile: - self._context.load_cert_chain(certfile, keyfile) - if npn_protocols: - self._context.set_npn_protocols(npn_protocols) - if ciphers: - self._context.set_ciphers(ciphers) - self.keyfile = keyfile - self.certfile = certfile - self.cert_reqs = cert_reqs - self.ssl_version = ssl_version - self.ca_certs = ca_certs - self.ciphers = ciphers - # Can't use sock.type as other flags (such as SOCK_NONBLOCK) get - # mixed in. + @classmethod + def _create(cls, sock, server_side=False, do_handshake_on_connect=True, + suppress_ragged_eofs=True, server_hostname=None, + context=None, session=None): if sock.getsockopt(SOL_SOCKET, SO_TYPE) != SOCK_STREAM: raise NotImplementedError("only stream sockets are supported") if server_side: if server_hostname: raise ValueError("server_hostname can only be specified " "in client mode") - if _session is not None: + if session is not None: raise ValueError("session can only be specified in " "client mode") - if self._context.check_hostname and not server_hostname: + if context.check_hostname and not server_hostname: raise ValueError("check_hostname requires server_hostname") - self._session = _session + + kwargs = dict( + family=sock.family, type=sock.type, proto=sock.proto, + fileno=sock.fileno() + ) + self = cls.__new__(cls, **kwargs) + super(SSLSocket, self).__init__(**kwargs) + self.settimeout(sock.gettimeout()) + sock.detach() + + self._context = context + self._session = session + self._closed = False + self._sslobj = None self.server_side = server_side - self.server_hostname = server_hostname + self.server_hostname = context._encode_hostname(server_hostname) self.do_handshake_on_connect = do_handshake_on_connect self.suppress_ragged_eofs = suppress_ragged_eofs - if sock is not None: - socket.__init__(self, - family=sock.family, - type=sock.type, - proto=sock.proto, - fileno=sock.fileno()) - self.settimeout(sock.gettimeout()) - sock.detach() - elif fileno is not None: - socket.__init__(self, fileno=fileno) - else: - socket.__init__(self, family=family, type=type, proto=proto) # See if we are connected try: @@ -800,28 +1024,27 @@ class SSLSocket(socket): else: connected = True - self._closed = False - self._sslobj = None self._connected = connected if connected: # create the SSL object try: - sslobj = self._context._wrap_socket(self, server_side, - server_hostname) - self._sslobj = SSLObject(sslobj, owner=self, - session=self._session) + self._sslobj = self._context._wrap_socket( + self, server_side, self.server_hostname, + owner=self, session=self._session, + ) if do_handshake_on_connect: timeout = self.gettimeout() if timeout == 0.0: # non-blocking raise ValueError("do_handshake_on_connect should not be specified for non-blocking sockets") self.do_handshake() - except (OSError, ValueError): self.close() raise + return self @property + @_sslcopydoc def context(self): return self._context @@ -831,8 +1054,8 @@ class SSLSocket(socket): self._sslobj.context = ctx @property + @_sslcopydoc def session(self): - """The SSLSession for client socket.""" if self._sslobj is not None: return self._sslobj.session @@ -843,8 +1066,8 @@ class SSLSocket(socket): self._sslobj.session = session @property + @_sslcopydoc def session_reused(self): - """Was the client session reused during handshake""" if self._sslobj is not None: return self._sslobj.session_reused @@ -869,10 +1092,13 @@ class SSLSocket(socket): Return zero-length string on EOF.""" self._checkClosed() - if not self._sslobj: + if self._sslobj is None: raise ValueError("Read on closed or unwrapped SSL socket.") try: - return self._sslobj.read(len, buffer) + if buffer is not None: + return self._sslobj.read(len, buffer) + else: + return self._sslobj.read(len) except SSLError as x: if x.args[0] == SSL_ERROR_EOF and self.suppress_ragged_eofs: if buffer is not None: @@ -887,74 +1113,76 @@ class SSLSocket(socket): number of bytes of DATA actually transmitted.""" self._checkClosed() - if not self._sslobj: + if self._sslobj is None: raise ValueError("Write on closed or unwrapped SSL socket.") return self._sslobj.write(data) + @_sslcopydoc def getpeercert(self, binary_form=False): - """Returns a formatted version of the data in the - certificate provided by the other end of the SSL channel. - Return None if no certificate was provided, {} if a - certificate was provided, but not validated.""" - self._checkClosed() self._check_connected() return self._sslobj.getpeercert(binary_form) + @_sslcopydoc def selected_npn_protocol(self): self._checkClosed() - if not self._sslobj or not _ssl.HAS_NPN: + if self._sslobj is None or not _ssl.HAS_NPN: return None else: return self._sslobj.selected_npn_protocol() + @_sslcopydoc def selected_alpn_protocol(self): self._checkClosed() - if not self._sslobj or not _ssl.HAS_ALPN: + if self._sslobj is None or not _ssl.HAS_ALPN: return None else: return self._sslobj.selected_alpn_protocol() + @_sslcopydoc def cipher(self): self._checkClosed() - if not self._sslobj: + if self._sslobj is None: return None else: return self._sslobj.cipher() + @_sslcopydoc def shared_ciphers(self): self._checkClosed() - if not self._sslobj: + if self._sslobj is None: return None - return self._sslobj.shared_ciphers() + else: + return self._sslobj.shared_ciphers() + @_sslcopydoc def compression(self): self._checkClosed() - if not self._sslobj: + if self._sslobj is None: return None else: return self._sslobj.compression() def send(self, data, flags=0): self._checkClosed() - if self._sslobj: + if self._sslobj is not None: if flags != 0: raise ValueError( "non-zero flags not allowed in calls to send() on %s" % self.__class__) return self._sslobj.write(data) else: - return socket.send(self, data, flags) + return super().send(data, flags) def sendto(self, data, flags_or_addr, addr=None): self._checkClosed() - if self._sslobj: + if self._sslobj is not None: raise ValueError("sendto not allowed on instances of %s" % self.__class__) elif addr is None: - return socket.sendto(self, data, flags_or_addr) + return super().sendto(data, flags_or_addr) else: - return socket.sendto(self, data, flags_or_addr, addr) + return super().sendto(data, flags_or_addr, addr) def sendmsg(self, *args, **kwargs): # Ensure programs don't send data unencrypted if they try to @@ -964,42 +1192,40 @@ class SSLSocket(socket): def sendall(self, data, flags=0): self._checkClosed() - if self._sslobj: + if self._sslobj is not None: if flags != 0: raise ValueError( "non-zero flags not allowed in calls to sendall() on %s" % self.__class__) count = 0 - # with memoryview(data) as view, view.cast("B") as byte_view: - # XXX RustPython TODO: proper memoryview implementation - byte_view = data - amount = len(byte_view) - while count < amount: - v = self.send(byte_view[count:]) - count += v + with memoryview(data) as view, view.cast("B") as byte_view: + amount = len(byte_view) + while count < amount: + v = self.send(byte_view[count:]) + count += v else: - return socket.sendall(self, data, flags) + return super().sendall(data, flags) def sendfile(self, file, offset=0, count=None): """Send a file, possibly by using os.sendfile() if this is a clear-text socket. Return the total number of bytes sent. """ - if self._sslobj is None: + if self._sslobj is not None: + return self._sendfile_use_send(file, offset, count) + else: # os.sendfile() works with plain sockets only return super().sendfile(file, offset, count) - else: - return self._sendfile_use_send(file, offset, count) def recv(self, buflen=1024, flags=0): self._checkClosed() - if self._sslobj: + if self._sslobj is not None: if flags != 0: raise ValueError( "non-zero flags not allowed in calls to recv() on %s" % self.__class__) return self.read(buflen) else: - return socket.recv(self, buflen, flags) + return super().recv(buflen, flags) def recv_into(self, buffer, nbytes=None, flags=0): self._checkClosed() @@ -1007,30 +1233,30 @@ class SSLSocket(socket): nbytes = len(buffer) elif nbytes is None: nbytes = 1024 - if self._sslobj: + if self._sslobj is not None: if flags != 0: raise ValueError( "non-zero flags not allowed in calls to recv_into() on %s" % self.__class__) return self.read(nbytes, buffer) else: - return socket.recv_into(self, buffer, nbytes, flags) + return super().recv_into(buffer, nbytes, flags) def recvfrom(self, buflen=1024, flags=0): self._checkClosed() - if self._sslobj: + if self._sslobj is not None: raise ValueError("recvfrom not allowed on instances of %s" % self.__class__) else: - return socket.recvfrom(self, buflen, flags) + return super().recvfrom(buflen, flags) def recvfrom_into(self, buffer, nbytes=None, flags=0): self._checkClosed() - if self._sslobj: + if self._sslobj is not None: raise ValueError("recvfrom_into not allowed on instances of %s" % self.__class__) else: - return socket.recvfrom_into(self, buffer, nbytes, flags) + return super().recvfrom_into(buffer, nbytes, flags) def recvmsg(self, *args, **kwargs): raise NotImplementedError("recvmsg not allowed on instances of %s" % @@ -1040,9 +1266,10 @@ class SSLSocket(socket): raise NotImplementedError("recvmsg_into not allowed on instances of " "%s" % self.__class__) + @_sslcopydoc def pending(self): self._checkClosed() - if self._sslobj: + if self._sslobj is not None: return self._sslobj.pending() else: return 0 @@ -1050,16 +1277,18 @@ class SSLSocket(socket): def shutdown(self, how): self._checkClosed() self._sslobj = None - socket.shutdown(self, how) + super().shutdown(how) + @_sslcopydoc def unwrap(self): if self._sslobj: - s = self._sslobj.unwrap() + s = self._sslobj.shutdown() self._sslobj = None return s else: raise ValueError("No SSL wrapper around " + str(self)) + @_sslcopydoc def verify_client_post_handshake(self): if self._sslobj: return self._sslobj.verify_client_post_handshake() @@ -1068,10 +1297,10 @@ class SSLSocket(socket): def _real_close(self): self._sslobj = None - socket._real_close(self) + super()._real_close() + @_sslcopydoc def do_handshake(self, block=False): - """Perform a TLS/SSL handshake.""" self._check_connected() timeout = self.gettimeout() try: @@ -1086,17 +1315,18 @@ class SSLSocket(socket): raise ValueError("can't connect in server-side mode") # Here we assume that the socket is client-side, and not # connected at the time of the call. We connect it, then wrap it. - if self._connected: + if self._connected or self._sslobj is not None: raise ValueError("attempt to connect already-connected SSLSocket!") - sslobj = self.context._wrap_socket(self, False, self.server_hostname) - self._sslobj = SSLObject(sslobj, owner=self, - session=self._session) + self._sslobj = self.context._wrap_socket( + self, False, self.server_hostname, + owner=self, session=self._session + ) try: if connect_ex: - rc = socket.connect_ex(self, addr) + rc = super().connect_ex(addr) else: rc = None - socket.connect(self, addr) + super().connect(addr) if not rc: self._connected = True if self.do_handshake_on_connect: @@ -1121,30 +1351,35 @@ class SSLSocket(socket): a tuple containing that new connection wrapped with a server-side SSL channel, and the address of the remote client.""" - newsock, addr = socket.accept(self) + newsock, addr = super().accept() newsock = self.context.wrap_socket(newsock, do_handshake_on_connect=self.do_handshake_on_connect, suppress_ragged_eofs=self.suppress_ragged_eofs, server_side=True) return newsock, addr + @_sslcopydoc def get_channel_binding(self, cb_type="tls-unique"): - """Get channel binding data for current connection. Raise ValueError - if the requested `cb_type` is not supported. Return bytes of the data - or None if the data is not available (e.g. before the handshake). - """ - if self._sslobj is None: + if self._sslobj is not None: + return self._sslobj.get_channel_binding(cb_type) + else: + if cb_type not in CHANNEL_BINDING_TYPES: + raise ValueError( + "{0} channel binding type not implemented".format(cb_type) + ) return None - return self._sslobj.get_channel_binding(cb_type) + @_sslcopydoc def version(self): - """ - Return a string identifying the protocol version used by the - current SSL channel, or None if there is no established channel. - """ - if self._sslobj is None: + if self._sslobj is not None: + return self._sslobj.version() + else: return None - return self._sslobj.version() + + +# Python does not support forward declaration of types. +SSLContext.sslsocket_class = SSLSocket +SSLContext.sslobject_class = SSLObject def wrap_socket(sock, keyfile=None, certfile=None, @@ -1153,12 +1388,25 @@ def wrap_socket(sock, keyfile=None, certfile=None, do_handshake_on_connect=True, suppress_ragged_eofs=True, ciphers=None): - return SSLSocket(sock=sock, keyfile=keyfile, certfile=certfile, - server_side=server_side, cert_reqs=cert_reqs, - ssl_version=ssl_version, ca_certs=ca_certs, - do_handshake_on_connect=do_handshake_on_connect, - suppress_ragged_eofs=suppress_ragged_eofs, - ciphers=ciphers) + + if server_side and not certfile: + raise ValueError("certfile must be specified for server-side " + "operations") + if keyfile and not certfile: + raise ValueError("certfile must be specified") + context = SSLContext(ssl_version) + context.verify_mode = cert_reqs + if ca_certs: + context.load_verify_locations(ca_certs) + if certfile: + context.load_cert_chain(certfile, keyfile) + if ciphers: + context.set_ciphers(ciphers) + return context.wrap_socket( + sock=sock, server_side=server_side, + do_handshake_on_connect=do_handshake_on_connect, + suppress_ragged_eofs=suppress_ragged_eofs + ) # some utility functions @@ -1200,9 +1448,10 @@ def DER_cert_to_PEM_cert(der_cert_bytes): PEM version of it as a string.""" f = str(base64.standard_b64encode(der_cert_bytes), 'ASCII', 'strict') - return (PEM_HEADER + '\n' + - textwrap.fill(f, 64) + '\n' + - PEM_FOOTER + '\n') + ss = [PEM_HEADER] + ss += [f[i:i+64] for i in range(0, len(f), 64)] + ss.append(PEM_FOOTER + '\n') + return '\n'.join(ss) def PEM_cert_to_DER_cert(pem_cert_string): """Takes a certificate in ASCII PEM format and returns the