Files
RustPython/crates/stdlib/src/ssl.rs
2026-02-08 16:49:18 +09:00

5168 lines
207 KiB
Rust

// spell-checker: ignore ssleof aesccm aesgcm capath getblocking setblocking ENDTLS TLSEXT
//! Pure Rust SSL/TLS implementation using rustls
//!
//! This module provides SSL/TLS support without requiring C dependencies.
//! It implements the Python ssl module API using:
//! - rustls: TLS protocol implementation
//! - x509-parser/x509-cert: Certificate parsing
//! - ring: Cryptographic primitives
//! - rustls-platform-verifier: Platform-native certificate verification
//!
//! DO NOT add openssl dependency here.
//!
//! Warning: This library contains AI-generated code and comments. Do not trust any code or comment without verification. Please have a qualified expert review the code and remove this notice after review.
// OID (Object Identifier) management module
mod oid;
// Certificate operations module (parsing, validation, conversion)
mod cert;
// OpenSSL compatibility layer (abstracts rustls operations)
mod compat;
// SSL exception types (shared with openssl backend)
mod error;
pub(crate) use _ssl::module_def;
#[allow(non_snake_case)]
#[allow(non_upper_case_globals)]
#[pymodule(with(error::ssl_error))]
mod _ssl {
use crate::{
common::{
hash::PyHash,
lock::{PyMutex, PyRwLock},
},
socket::{PySocket, SelectKind, sock_select, timeout_error_msg},
vm::{
AsObject, Py, PyObject, PyObjectRef, PyPayload, PyRef, PyResult, TryFromObject,
VirtualMachine,
builtins::{PyBaseExceptionRef, PyBytesRef, PyListRef, PyStrRef, PyType, PyTypeRef},
convert::IntoPyException,
function::{
ArgBytesLike, ArgMemoryBuffer, Either, FuncArgs, OptionalArg, PyComparisonValue,
},
stdlib::warnings,
types::{Comparable, Constructor, Hashable, PyComparisonOp, Representable},
},
};
// Import error types used in this module (others are exposed via pymodule(with(...)))
use super::error::{
PySSLError, create_ssl_eof_error, create_ssl_want_read_error, create_ssl_want_write_error,
create_ssl_zero_return_error,
};
use alloc::sync::Arc;
use core::{
sync::atomic::{AtomicUsize, Ordering},
time::Duration,
};
use std::{collections::HashMap, time::SystemTime};
// Rustls imports
use parking_lot::{Mutex as ParkingMutex, RwLock as ParkingRwLock};
use pem_rfc7468::{LineEnding, encode_string};
use rustls::{
ClientConfig, ClientConnection, RootCertStore, ServerConfig, ServerConnection,
client::{ClientSessionMemoryCache, ClientSessionStore},
crypto::SupportedKxGroup,
pki_types::{CertificateDer, CertificateRevocationListDer, PrivateKeyDer, ServerName},
server::{ClientHello, ResolvesServerCert},
sign::CertifiedKey,
version::{TLS12, TLS13},
};
use sha2::{Digest, Sha256};
// Import certificate operations module
use super::cert;
// Import OID module
use super::oid;
// Import compat module (OpenSSL compatibility layer)
use super::compat::{
ClientConfigOptions, MultiCertResolver, ProtocolSettings, ServerConfigOptions, SslError,
TlsConnection, create_client_config, create_server_config, curve_name_to_kx_group,
extract_cipher_info, get_cipher_encryption_desc, is_blocking_io_error,
normalize_cipher_name, ssl_do_handshake,
};
// Type aliases for better readability.
// `SessionCache` and `SniCertName` are declared further down in this module.
/// Certificate and private key pair used in SSL contexts
///
/// The first element is the rustls `CertifiedKey`; the second is the
/// matching `PrivateKeyDer`.
type CertKeyPair = (Arc<CertifiedKey>, PrivateKeyDer<'static>);
// Constants matching Python ssl module
// SSL/TLS protocol selectors (PROTOCOL_* values; compared against
// `PySSLContext::protocol`, e.g. in the num_tickets setter)
#[pyattr]
const PROTOCOL_TLS: i32 = 2; // Auto-negotiate best version
#[pyattr]
const PROTOCOL_SSLv23: i32 = PROTOCOL_TLS; // Alias for PROTOCOL_TLS
#[pyattr]
const PROTOCOL_TLS_CLIENT: i32 = 16;
#[pyattr]
const PROTOCOL_TLS_SERVER: i32 = 17;
// Note: rustls doesn't support TLS 1.0/1.1 for security reasons.
// PROTOCOL_TLSv1 and PROTOCOL_TLSv1_1 are defined for API compatibility only;
// they are documented as raising errors if used (that check is not in this section).
#[pyattr]
const PROTOCOL_TLSv1: i32 = 3;
#[pyattr]
const PROTOCOL_TLSv1_1: i32 = 4;
// Selectors for the two versions rustls does support
#[pyattr]
const PROTOCOL_TLSv1_2: i32 = 5;
#[pyattr]
const PROTOCOL_TLSv1_3: i32 = 6;
// Protocol version constants for TLSVersion enum
// (accepted by the minimum_version / maximum_version setters)
#[pyattr]
const PROTO_SSLv3: i32 = 0x0300;
#[pyattr]
const PROTO_TLSv1: i32 = 0x0301;
#[pyattr]
const PROTO_TLSv1_1: i32 = 0x0302;
#[pyattr]
const PROTO_TLSv1_2: i32 = 0x0303;
#[pyattr]
const PROTO_TLSv1_3: i32 = 0x0304;
// Minimum and maximum supported protocol versions for rustls
// Use special values -2 and -1 to avoid enum name conflicts.
// The version setters translate these sentinels into 0 (auto-negotiate) or
// into MINIMUM_VERSION / MAXIMUM_VERSION below.
#[pyattr]
const PROTO_MINIMUM_SUPPORTED: i32 = -2; // special value
#[pyattr]
const PROTO_MAXIMUM_SUPPORTED: i32 = -1; // special value
// Internal constants for rustls actual supported versions
// rustls only supports TLS 1.2 and TLS 1.3
// (no #[pyattr]: these are not exported to Python)
const MINIMUM_VERSION: i32 = PROTO_TLSv1_2; // 0x0303
const MAXIMUM_VERSION: i32 = PROTO_TLSv1_3; // 0x0304
// Buffer sizes and limits (OpenSSL/CPython compatibility)
// Maximum accepted private-key password length in bytes (enforced in load_cert_chain)
const PEM_BUFSIZE: usize = 1024;
// OpenSSL: ssl/ssl_local.h
const SSL3_RT_MAX_PLAIN_LENGTH: usize = 16384;
// SSL session cache size (common practice, similar to OpenSSL defaults)
const SSL_SESSION_CACHE_SIZE: usize = 256;
// Certificate verification modes
// (the verify_mode setter accepts exactly the range CERT_NONE..=CERT_REQUIRED)
#[pyattr]
const CERT_NONE: i32 = 0;
#[pyattr]
const CERT_OPTIONAL: i32 = 1;
#[pyattr]
const CERT_REQUIRED: i32 = 2;
// Verification flags (bit values stored in `PySSLContext::verify_flags`)
#[pyattr]
const VERIFY_DEFAULT: i32 = 0;
#[pyattr]
const VERIFY_CRL_CHECK_LEAF: i32 = 4;
#[pyattr]
const VERIFY_CRL_CHECK_CHAIN: i32 = 12;
#[pyattr]
const VERIFY_X509_STRICT: i32 = 32;
#[pyattr]
const VERIFY_ALLOW_PROXY_CERTS: i32 = 64;
#[pyattr]
const VERIFY_X509_TRUSTED_FIRST: i32 = 32768;
#[pyattr]
const VERIFY_X509_PARTIAL_CHAIN: i32 = 0x80000;
// Options (OpenSSL-compatible flags, mostly no-op in rustls)
// Bit values stored in `PySSLContext::options`. The OP_NO_SSL*/OP_NO_TLS*
// members trigger a DeprecationWarning in the options setter when newly set.
#[pyattr]
const OP_NO_SSLv2: i32 = 0x00000000; // Not supported anyway
#[pyattr]
const OP_NO_SSLv3: i32 = 0x02000000;
#[pyattr]
const OP_NO_TLSv1: i32 = 0x04000000;
#[pyattr]
const OP_NO_TLSv1_1: i32 = 0x10000000;
#[pyattr]
const OP_NO_TLSv1_2: i32 = 0x08000000;
#[pyattr]
const OP_NO_TLSv1_3: i32 = 0x20000000;
#[pyattr]
const OP_NO_COMPRESSION: i32 = 0x00020000;
#[pyattr]
const OP_CIPHER_SERVER_PREFERENCE: i32 = 0x00400000;
#[pyattr]
const OP_SINGLE_DH_USE: i32 = 0x00000000; // No-op in rustls
#[pyattr]
const OP_SINGLE_ECDH_USE: i32 = 0x00000000; // No-op in rustls
#[pyattr]
const OP_NO_TICKET: i32 = 0x00004000;
#[pyattr]
const OP_LEGACY_SERVER_CONNECT: i32 = 0x00000004;
#[pyattr]
const OP_NO_RENEGOTIATION: i32 = 0x40000000;
#[pyattr]
const OP_IGNORE_UNEXPECTED_EOF: i32 = 0x00000080;
#[pyattr]
const OP_ENABLE_MIDDLEBOX_COMPAT: i32 = 0x00100000;
#[pyattr]
const OP_ALL: i32 = 0x00000BFB; // Combined "safe" options (reduced for i32, excluding OP_LEGACY_SERVER_CONNECT for OpenSSL 3.0.0+ compatibility)
// Alert types (matching _TLSAlertType enum)
// Each constant is exported to Python under the same name via #[pyattr].
#[pyattr]
const ALERT_DESCRIPTION_CLOSE_NOTIFY: i32 = 0;
#[pyattr]
const ALERT_DESCRIPTION_UNEXPECTED_MESSAGE: i32 = 10;
#[pyattr]
const ALERT_DESCRIPTION_BAD_RECORD_MAC: i32 = 20;
#[pyattr]
const ALERT_DESCRIPTION_DECRYPTION_FAILED: i32 = 21;
#[pyattr]
const ALERT_DESCRIPTION_RECORD_OVERFLOW: i32 = 22;
#[pyattr]
const ALERT_DESCRIPTION_DECOMPRESSION_FAILURE: i32 = 30;
#[pyattr]
const ALERT_DESCRIPTION_HANDSHAKE_FAILURE: i32 = 40;
#[pyattr]
const ALERT_DESCRIPTION_NO_CERTIFICATE: i32 = 41;
#[pyattr]
const ALERT_DESCRIPTION_BAD_CERTIFICATE: i32 = 42;
#[pyattr]
const ALERT_DESCRIPTION_UNSUPPORTED_CERTIFICATE: i32 = 43;
#[pyattr]
const ALERT_DESCRIPTION_CERTIFICATE_REVOKED: i32 = 44;
#[pyattr]
const ALERT_DESCRIPTION_CERTIFICATE_EXPIRED: i32 = 45;
#[pyattr]
const ALERT_DESCRIPTION_CERTIFICATE_UNKNOWN: i32 = 46;
#[pyattr]
const ALERT_DESCRIPTION_ILLEGAL_PARAMETER: i32 = 47;
#[pyattr]
const ALERT_DESCRIPTION_UNKNOWN_CA: i32 = 48;
#[pyattr]
const ALERT_DESCRIPTION_ACCESS_DENIED: i32 = 49;
#[pyattr]
const ALERT_DESCRIPTION_DECODE_ERROR: i32 = 50;
#[pyattr]
const ALERT_DESCRIPTION_DECRYPT_ERROR: i32 = 51;
#[pyattr]
const ALERT_DESCRIPTION_EXPORT_RESTRICTION: i32 = 60;
#[pyattr]
const ALERT_DESCRIPTION_PROTOCOL_VERSION: i32 = 70;
#[pyattr]
const ALERT_DESCRIPTION_INSUFFICIENT_SECURITY: i32 = 71;
#[pyattr]
const ALERT_DESCRIPTION_INTERNAL_ERROR: i32 = 80;
#[pyattr]
const ALERT_DESCRIPTION_INAPPROPRIATE_FALLBACK: i32 = 86;
#[pyattr]
const ALERT_DESCRIPTION_USER_CANCELLED: i32 = 90;
#[pyattr]
const ALERT_DESCRIPTION_NO_RENEGOTIATION: i32 = 100;
#[pyattr]
const ALERT_DESCRIPTION_MISSING_EXTENSION: i32 = 109;
#[pyattr]
const ALERT_DESCRIPTION_UNSUPPORTED_EXTENSION: i32 = 110;
#[pyattr]
const ALERT_DESCRIPTION_CERTIFICATE_UNOBTAINABLE: i32 = 111;
#[pyattr]
const ALERT_DESCRIPTION_UNRECOGNIZED_NAME: i32 = 112;
#[pyattr]
const ALERT_DESCRIPTION_BAD_CERTIFICATE_STATUS_RESPONSE: i32 = 113;
#[pyattr]
const ALERT_DESCRIPTION_BAD_CERTIFICATE_HASH_VALUE: i32 = 114;
#[pyattr]
const ALERT_DESCRIPTION_UNKNOWN_PSK_IDENTITY: i32 = 115;
#[pyattr]
const ALERT_DESCRIPTION_CERTIFICATE_REQUIRED: i32 = 116;
#[pyattr]
const ALERT_DESCRIPTION_NO_APPLICATION_PROTOCOL: i32 = 120;
// Version info - reporting as OpenSSL 3.3.0 for compatibility
// (the backend is rustls; these values only mimic what an OpenSSL build reports)
#[pyattr]
const OPENSSL_VERSION_NUMBER: i32 = 0x30300000; // OpenSSL 3.3.0 (808452096)
#[pyattr]
const OPENSSL_VERSION: &str = "OpenSSL 3.3.0 (rustls/0.23)";
// NOTE(review): the last tuple element below is 15, while the low nibble of
// OPENSSL_VERSION_NUMBER is 0 - confirm whether the two are meant to agree.
#[pyattr]
const OPENSSL_VERSION_INFO: (i32, i32, i32, i32, i32) = (3, 3, 0, 0, 15); // 3.3.0 release
#[pyattr]
const _OPENSSL_API_VERSION: (i32, i32, i32, i32, i32) = (3, 3, 0, 0, 15); // 3.3.0 release
// Default cipher list for rustls - using modern secure ciphers
// (colon-separated suite names)
#[pyattr]
const _DEFAULT_CIPHERS: &str =
"TLS_AES_256_GCM_SHA384:TLS_AES_128_GCM_SHA256:TLS_CHACHA20_POLY1305_SHA256";
// Feature flags (HAS_*) advertised to Python code
#[pyattr]
const HAS_SNI: bool = true;
#[pyattr]
const HAS_TLS_UNIQUE: bool = false; // Not supported
#[pyattr]
const HAS_ECDH: bool = true;
#[pyattr]
const HAS_NPN: bool = false; // Deprecated, use ALPN
#[pyattr]
const HAS_ALPN: bool = true;
#[pyattr]
const HAS_PSK: bool = false; // PSK not supported in rustls
#[pyattr]
const HAS_SSLv2: bool = false;
#[pyattr]
const HAS_SSLv3: bool = false;
#[pyattr]
const HAS_TLSv1: bool = false; // Not supported for security
#[pyattr]
const HAS_TLSv1_1: bool = false; // Not supported for security
#[pyattr]
const HAS_TLSv1_2: bool = true; // rustls supports TLS 1.2
#[pyattr]
const HAS_TLSv1_3: bool = true;
// Encoding constants (matching OpenSSL)
#[pyattr]
const ENCODING_PEM: i32 = 1;
#[pyattr]
const ENCODING_DER: i32 = 2;
// ENCODING_PEM with the 0x100 bit added
#[pyattr]
const ENCODING_PEM_AUX: i32 = 0x101; // PEM + 0x100
/// Validate a `server_hostname` value before it is used for a TLS connection.
///
/// Rules, checked in this order (the first violated rule decides the error):
/// - must not be empty (`ValueError`)
/// - must not start with a dot (`ValueError`)
/// - must not contain a NUL character (`TypeError`)
/// - must not be longer than 253 bytes, the DNS name limit (`ValueError`)
///
/// IP address literals are accepted as `server_hostname`; SNI will simply not
/// be sent for them.
///
/// Returns `Ok(())` when every rule passes.
fn validate_hostname(hostname: &str, vm: &VirtualMachine) -> PyResult<()> {
    if hostname.is_empty() {
        Err(vm.new_value_error("server_hostname cannot be an empty string"))
    } else if hostname.starts_with('.') {
        Err(vm.new_value_error("server_hostname cannot start with a dot"))
    } else if hostname.contains('\0') {
        Err(vm.new_type_error("embedded null character"))
    } else if hostname.len() > 253 {
        Err(vm.new_value_error("server_hostname is too long (maximum 253 characters)"))
    } else {
        Ok(())
    }
}
// Server certificate resolver backed by shared, mutable SNI state.
// `resolve()` records the server name announced by the client and returns
// whichever certificate is currently stored in the shared state (per the
// original note, the Python SNI callback is what updates that certificate).
#[derive(Debug)]
struct SniCertResolver {
    // Shared `(certificate, server_name)` pair
    sni_state: Arc<ParkingMutex<SniCertName>>,
}

impl ResolvesServerCert for SniCertResolver {
    fn resolve(&self, client_hello: ClientHello<'_>) -> Option<Arc<CertifiedKey>> {
        let mut shared = self.sni_state.lock();
        // Remember the requested name; a hello without SNI clears any earlier value.
        shared.1 = client_hello.server_name().map(|sni| sni.to_string());
        // Always answer with the certificate currently held in the shared state.
        Some(Arc::clone(&shared.0))
    }
}
// Session data structure for tracking TLS sessions.
// Holds metadata only; instances are created by PythonClientSessionStore
// whenever a TLS 1.2 session or TLS 1.3 ticket is stored.
#[derive(Debug, Clone)]
struct SessionData {
// Name of the server the session was recorded for
#[allow(dead_code)]
server_name: String,
// Synthetic identifier produced by generate_session_id_from_metadata
session_id: Vec<u8>,
// Wall-clock time at which the record was created
creation_time: SystemTime,
// Lifetime value; both store paths record 7200 (presumably seconds - confirm)
lifetime: u64,
}
// Type alias to simplify complex session cache type.
// Shared, lock-protected map from a byte key (PythonClientSessionStore uses
// the server name bytes) to an individually lockable SessionData record.
type SessionCache = Arc<ParkingRwLock<HashMap<Vec<u8>, Arc<ParkingMutex<SessionData>>>>>;
// Type alias for SNI state: (current certificate, server name from the last client hello)
type SniCertName = (Arc<CertifiedKey>, Option<String>);
// SESSION EMULATION IMPLEMENTATION
//
// IMPORTANT: This is an EMULATION of CPython's SSL session management.
// Rustls 0.23 does NOT expose session data (ticket bytes, session IDs, etc.)
// through public APIs. All session value fields are private.
//
// LIMITATIONS:
// - Session IDs are generated from metadata (server name + timestamp hash)
// NOT actual TLS session IDs
// - Ticket data is not stored (Rustls keeps it internally)
// - Session resumption works (via Rustls's automatic mechanism)
// but we can't access the actual session state
//
// This implementation provides:
// ✓ session.id - synthetic ID based on metadata
// ✓ session.time - creation timestamp
// ✓ session.timeout - default lifetime value
// ✓ session.has_ticket - always True when session exists
// ✓ session_reused - tracked via handshake_kind()
// ✗ Actual TLS session ID/ticket data - NOT ACCESSIBLE
/// Generate a synthetic session ID from a server name and a timestamp.
///
/// NOTE: This is NOT the actual TLS session ID, just an identifier derived
/// from metadata: the first 16 bytes of SHA-256 over the server name followed
/// by the timestamp's whole seconds since the Unix epoch (little-endian).
///
/// A `time` earlier than the Unix epoch (e.g. a badly set system clock) is
/// hashed as 0 seconds instead of panicking.
fn generate_session_id_from_metadata(server_name: &str, time: &SystemTime) -> Vec<u8> {
    // `duration_since` fails when `time` precedes the epoch; fall back to 0
    // rather than unwrapping, since this runs on every stored session.
    let secs_since_epoch = time
        .duration_since(SystemTime::UNIX_EPOCH)
        .map_or(0, |elapsed| elapsed.as_secs());

    let mut hasher = Sha256::new();
    hasher.update(server_name.as_bytes());
    hasher.update(secs_since_epoch.to_le_bytes());
    hasher.finalize()[..16].to_vec()
}
// Custom ClientSessionStore that tracks session metadata for Python access.
// Every operation is forwarded to `inner` (a ClientSessionMemoryCache), which
// performs the real storage. Whenever a TLS 1.2 session or a TLS 1.3 ticket
// is stored, a SessionData record is additionally written to `session_cache`,
// keyed by the server name bytes.
#[derive(Debug)]
struct PythonClientSessionStore {
    inner: Arc<ClientSessionMemoryCache>,
    session_cache: SessionCache,
}

impl PythonClientSessionStore {
    // Write (or overwrite) the Python-visible metadata record for `server_name`.
    // The session value's own fields are not accessible, so the ID is synthetic
    // and the lifetime is the fixed default of 7200.
    fn record_session_metadata(&self, server_name: &ServerName<'_>) {
        let recorded_at = SystemTime::now();
        let host = server_name.to_str();
        let metadata = SessionData {
            server_name: host.to_string(),
            session_id: generate_session_id_from_metadata(&host, &recorded_at),
            creation_time: recorded_at,
            lifetime: 7200,
        };
        let cache_key = host.as_bytes().to_vec();
        self.session_cache
            .write()
            .insert(cache_key, Arc::new(ParkingMutex::new(metadata)));
    }
}

impl ClientSessionStore for PythonClientSessionStore {
    fn set_kx_hint(&self, server_name: ServerName<'static>, group: rustls::NamedGroup) {
        self.inner.set_kx_hint(server_name, group);
    }

    fn kx_hint(&self, server_name: &ServerName<'_>) -> Option<rustls::NamedGroup> {
        self.inner.kx_hint(server_name)
    }

    fn set_tls12_session(
        &self,
        server_name: ServerName<'static>,
        value: rustls::client::Tls12ClientSessionValue,
    ) {
        // The inner cache keeps the real session; the mirror only keeps metadata.
        self.inner.set_tls12_session(server_name.clone(), value);
        self.record_session_metadata(&server_name);
    }

    fn tls12_session(
        &self,
        server_name: &ServerName<'_>,
    ) -> Option<rustls::client::Tls12ClientSessionValue> {
        self.inner.tls12_session(server_name)
    }

    fn remove_tls12_session(&self, server_name: &ServerName<'static>) {
        self.inner.remove_tls12_session(server_name);
        // Drop the mirrored metadata record as well.
        let host = server_name.to_str();
        self.session_cache.write().remove(host.as_bytes());
    }

    fn insert_tls13_ticket(
        &self,
        server_name: ServerName<'static>,
        value: rustls::client::Tls13ClientSessionValue,
    ) {
        // Same two-step bookkeeping as for TLS 1.2 sessions.
        self.inner.insert_tls13_ticket(server_name.clone(), value);
        self.record_session_metadata(&server_name);
    }

    fn take_tls13_ticket(
        &self,
        server_name: &ServerName<'static>,
    ) -> Option<rustls::client::Tls13ClientSessionValue> {
        self.inner.take_tls13_ticket(server_name)
    }
}
/// Parse length-prefixed ALPN protocol list
///
/// Format: [len1, proto1..., len2, proto2..., ...]
///
/// This is the wire format used by Python's ssl.py when calling _set_alpn_protocols().
/// Each protocol is prefixed with a single byte indicating its length.
///
/// # Arguments
/// * `bytes` - The length-prefixed protocol data
/// * `vm` - VirtualMachine for error creation
///
/// # Returns
/// * `Ok(Vec<Vec<u8>>)` - List of protocol names as byte vectors (empty for empty input)
/// * `Err(PyBaseExceptionRef)` - ValueError when a length byte is 0 or when a
///   length byte announces more data than remains in `bytes`
fn parse_length_prefixed_alpn(bytes: &[u8], vm: &VirtualMachine) -> PyResult<Vec<Vec<u8>>> {
    let mut alpn_list = Vec::new();
    let mut offset = 0;
    // The loop condition guarantees the length byte at `offset` exists, so no
    // separate "can we read the length byte" check is needed.
    while offset < bytes.len() {
        let length_offset = offset;
        let proto_len = usize::from(bytes[length_offset]);
        offset += 1;

        // A zero-length protocol name is not allowed.
        if proto_len == 0 {
            return Err(vm.new_value_error(format!(
                "Invalid ALPN protocol data: protocol length cannot be 0 at offset {}",
                length_offset
            )));
        }

        // The announced length must fit in what is left of the buffer.
        let remaining = bytes.len() - offset;
        if proto_len > remaining {
            return Err(vm.new_value_error(format!(
                "Invalid ALPN protocol data: expected {} bytes at offset {}, but only {} bytes remain",
                proto_len, offset, remaining
            )));
        }

        let end = offset + proto_len;
        alpn_list.push(bytes[offset..end].to_vec());
        offset = end;
    }
    Ok(alpn_list)
}
/// Parse OpenSSL cipher string to rustls SupportedCipherSuite list
///
/// The string is split on `:`; each component is trimmed and handled as follows:
/// - empty components (e.g. from `"A::B"` or a trailing `:`) are ignored
/// - components starting with `!` (exclusions) or `+` (priority markers) are ignored
/// - "ALL", "DEFAULT" or "HIGH" → every available suite
/// - anything else is split on `+` and a suite is selected only when ALL
///   keywords match its name, e.g. "ECDHE+AESGCM" → ECDHE AND AESGCM
///
/// Keyword examples:
/// - "AES128" → filters for AES_128
/// - "AES256" → filters for AES_256
/// - "AES128:AES256" → both
///
/// Suites are returned in selection order with duplicates removed (the first
/// occurrence is kept).
///
/// # Errors
/// Returns `"No cipher can be selected"` when the input is empty or no suite
/// was selected.
fn parse_cipher_string(cipher_str: &str) -> Result<Vec<rustls::SupportedCipherSuite>, String> {
    use rustls::crypto::aws_lc_rs::ALL_CIPHER_SUITES;

    /// Decide whether one OpenSSL keyword matches the Debug name of a rustls suite.
    fn keyword_matches(keyword: &str, name: &str) -> bool {
        // Handle common OpenSSL pattern variations
        if keyword.contains("AES128") {
            name.contains("AES_128")
        } else if keyword.contains("AES256") {
            name.contains("AES_256")
        } else {
            match keyword {
                // AES with GCM / CCM mode
                "AESGCM" => name.contains("AES") && name.contains("GCM"),
                "AESCCM" => name.contains("AES") && name.contains("CCM"),
                // DHE but not ECDHE
                "DHE" => name.contains("DHE") && !name.contains("ECDHE"),
                // ECDH but not ECDHE
                "ECDH" => name.contains("ECDH") && !name.contains("ECDHE"),
                // DH but not DHE or ECDH
                "DH" => {
                    name.contains("DH") && !name.contains("DHE") && !name.contains("ECDH")
                }
                // CHACHA20, ECDHE, RSA, AES, ECDSA and every other keyword
                // are plain substring matches.
                other => name.contains(other),
            }
        }
    }

    if cipher_str.is_empty() {
        return Err("No cipher can be selected".to_string());
    }

    // Format each suite's Debug name once instead of once per component.
    let named_suites: Vec<(String, rustls::SupportedCipherSuite)> = ALL_CIPHER_SUITES
        .iter()
        .map(|suite| (format!("{:?}", suite.suite()), *suite))
        .collect();

    let mut selected: Vec<rustls::SupportedCipherSuite> = Vec::new();
    for part in cipher_str.split(':').map(str::trim) {
        // An empty component would otherwise substring-match every suite.
        if part.is_empty() || part.starts_with('!') || part.starts_with('+') {
            continue;
        }
        match part {
            "ALL" | "DEFAULT" | "HIGH" => selected.extend_from_slice(ALL_CIPHER_SUITES),
            _ => {
                // Compound pattern: every `+`-separated keyword must match.
                let keywords: Vec<&str> = part.split('+').collect();
                for (name, suite) in &named_suites {
                    if keywords.iter().all(|keyword| keyword_matches(keyword, name)) {
                        selected.push(*suite);
                    }
                }
            }
        }
    }

    // Remove duplicates anywhere in the list (not just adjacent ones), keeping
    // the first occurrence so the selection order is preserved.
    let mut seen = Vec::with_capacity(selected.len());
    selected.retain(|suite| {
        let id = suite.suite();
        if seen.contains(&id) {
            false
        } else {
            seen.push(id);
            true
        }
    });

    if selected.is_empty() {
        Err("No cipher can be selected".to_string())
    } else {
        Ok(selected)
    }
}
// SSLContext - manages TLS configuration
// Exposed to Python as `ssl._SSLContext`. Mutable settings sit behind PyRwLock
// so they can be changed through `&self`. Only the two callback fields hold
// Python objects; every other field is marked #[pytraverse(skip)].
#[pyattr]
#[pyclass(name = "_SSLContext", module = "ssl", traverse)]
#[derive(Debug, PyPayload)]
struct PySSLContext {
// PROTOCOL_* selector for this context (no lock: only read in the code shown here)
#[pytraverse(skip)]
protocol: i32,
// Whether hostname checking is enabled (enabling it upgrades CERT_NONE to CERT_REQUIRED)
#[pytraverse(skip)]
check_hostname: PyRwLock<bool>,
// One of CERT_NONE / CERT_OPTIONAL / CERT_REQUIRED
#[pytraverse(skip)]
verify_mode: PyRwLock<i32>,
// Bitmask of VERIFY_* flags (stored without validation)
#[pytraverse(skip)]
verify_flags: PyRwLock<i32>,
// Rustls configuration (built lazily)
#[allow(dead_code)]
#[pytraverse(skip)]
client_config: PyRwLock<Option<Arc<ClientConfig>>>,
#[allow(dead_code)]
#[pytraverse(skip)]
server_config: PyRwLock<Option<Arc<ServerConfig>>>,
// Certificate store
#[pytraverse(skip)]
root_certs: PyRwLock<RootCertStore>,
// Store full CA certificates for get_ca_certs()
// RootCertStore only keeps TrustAnchors, not full certificates
#[pytraverse(skip)]
ca_certs_der: PyRwLock<Vec<Vec<u8>>>,
// Store CA certificates from capath for lazy loading simulation
// (CPython only returns these in get_ca_certs() after they're used in handshake)
#[pytraverse(skip)]
capath_certs_der: PyRwLock<Vec<Vec<u8>>>,
// Certificate Revocation Lists for CRL checking
#[pytraverse(skip)]
crls: PyRwLock<Vec<CertificateRevocationListDer<'static>>>,
// Server certificate/key pairs (supports multiple for RSA+ECC dual mode)
// OpenSSL allows multiple cert/key pairs to be loaded, and selects the appropriate
// one based on client capabilities during handshake
// Stored as (CertifiedKey, PrivateKeyDer) to support both server and client usage
#[pytraverse(skip)]
cert_keys: PyRwLock<Vec<CertKeyPair>>,
// Options (bitmask of OP_* flags)
#[allow(dead_code)]
#[pytraverse(skip)]
options: PyRwLock<i32>,
// ALPN protocols (one byte vector per protocol name)
#[allow(dead_code)]
#[pytraverse(skip)]
alpn_protocols: PyRwLock<Vec<Vec<u8>>>,
// ALPN strict matching flag
// When false (default), mimics OpenSSL behavior: no ALPN negotiation failure
// When true, requires ALPN match (Rustls default behavior)
#[allow(dead_code)]
#[pytraverse(skip)]
require_alpn_match: PyRwLock<bool>,
// TLS 1.3 features
#[pytraverse(skip)]
post_handshake_auth: PyRwLock<bool>,
// Ticket count; the setter rejects negatives and non-server contexts
#[pytraverse(skip)]
num_tickets: PyRwLock<i32>,
// Protocol version limits (0 means "not pinned"; the getters report it as
// PROTO_MINIMUM_SUPPORTED / PROTO_MAXIMUM_SUPPORTED respectively)
#[pytraverse(skip)]
minimum_version: PyRwLock<i32>,
#[pytraverse(skip)]
maximum_version: PyRwLock<i32>,
// SNI callback for server-side (contains PyObjectRef - needs GC tracking)
sni_callback: PyRwLock<Option<PyObjectRef>>,
// Message callback for debugging (contains PyObjectRef - needs GC tracking)
msg_callback: PyRwLock<Option<PyObjectRef>>,
// ECDH curve name for key exchange
#[pytraverse(skip)]
ecdh_curve: PyRwLock<Option<String>>,
// Certificate statistics for cert_store_stats()
#[pytraverse(skip)]
ca_cert_count: PyRwLock<usize>, // Number of CA certificates
#[pytraverse(skip)]
x509_cert_count: PyRwLock<usize>, // Total number of certificates
// Session management (Python-visible metadata mirror, see SessionCache)
#[pytraverse(skip)]
client_session_cache: SessionCache,
// Rustls session store for actual TLS session resumption
#[pytraverse(skip)]
rustls_session_store: Arc<PythonClientSessionStore>,
// Rustls server session store for server-side session resumption
#[pytraverse(skip)]
rustls_server_session_store: Arc<rustls::server::ServerSessionMemoryCache>,
// Shared ticketer for TLS 1.2 session tickets
#[pytraverse(skip)]
server_ticketer: Arc<dyn rustls::server::ProducesTickets>,
// Server-side session statistics
#[pytraverse(skip)]
accept_count: AtomicUsize, // Total number of accepts
#[pytraverse(skip)]
session_hits: AtomicUsize, // Number of session reuses
// Cipher suite selection
/// Selected cipher suites (None = use all rustls defaults)
#[pytraverse(skip)]
selected_ciphers: PyRwLock<Option<Vec<rustls::SupportedCipherSuite>>>,
}
// Argument bundle (parsed via FromArgs) for wrapping a socket object.
// NOTE(review): presumably consumed by the context's socket-wrapping method,
// which is not visible in this part of the file.
#[derive(FromArgs)]
struct WrapSocketArgs {
// Socket object to wrap
sock: PyObjectRef,
// true for the server end of the connection, false for the client end
server_side: bool,
// Optional positional hostname; may be passed explicitly as None
#[pyarg(positional, optional)]
server_hostname: OptionalArg<Option<PyStrRef>>,
// Optional keyword-only owner object
#[pyarg(named, optional)]
owner: OptionalArg<PyObjectRef>,
// Optional keyword-only session object
#[pyarg(named, optional)]
session: OptionalArg<PyObjectRef>,
}
// Argument bundle (parsed via FromArgs) for wrapping a pair of memory BIOs.
// NOTE(review): presumably consumed by the context's BIO-wrapping method,
// which is not visible in this part of the file.
#[derive(FromArgs)]
struct WrapBioArgs {
// The two PyMemoryBIO objects, named for the direction of data flow
incoming: PyRef<PyMemoryBIO>,
outgoing: PyRef<PyMemoryBIO>,
// Optional keyword-only flag selecting the server end
#[pyarg(named, optional)]
server_side: OptionalArg<bool>,
// Optional keyword-only hostname; may be passed explicitly as None
#[pyarg(named, optional)]
server_hostname: OptionalArg<Option<PyStrRef>>,
// Optional keyword-only owner object
#[pyarg(named, optional)]
owner: OptionalArg<PyObjectRef>,
// Optional keyword-only session object
#[pyarg(named, optional)]
session: OptionalArg<PyObjectRef>,
}
// Argument bundle (parsed via FromArgs) with three optional CA sources.
// Each may be given positionally or by keyword, as str or bytes-like, or as
// an explicit None; a value of any other type is rejected with the error_msg shown.
#[derive(FromArgs)]
struct LoadVerifyLocationsArgs {
// Path to a CA file
#[pyarg(any, optional, error_msg = "path should be a str or bytes")]
cafile: OptionalArg<Option<Either<PyStrRef, ArgBytesLike>>>,
// Path to a CA directory
#[pyarg(any, optional, error_msg = "path should be a str or bytes")]
capath: OptionalArg<Option<Either<PyStrRef, ArgBytesLike>>>,
// In-memory CA data
#[pyarg(any, optional, error_msg = "cadata should be a str or bytes")]
cadata: OptionalArg<Option<Either<PyStrRef, ArgBytesLike>>>,
}
// Argument bundle (parsed via FromArgs) for load_cert_chain.
#[derive(FromArgs)]
struct LoadCertChainArgs {
// Required certificate path (str or bytes-like)
#[pyarg(any, error_msg = "path should be a str or bytes")]
certfile: Either<PyStrRef, ArgBytesLike>,
// Optional key path; load_cert_chain falls back to the certfile path when
// this is absent or None
#[pyarg(any, optional, error_msg = "path should be a str or bytes")]
keyfile: OptionalArg<Option<Either<PyStrRef, ArgBytesLike>>>,
// Optional password object; interpreted by parse_password_argument
// (str, bytes-like, or callable per the comment in load_cert_chain)
#[pyarg(any, optional)]
password: OptionalArg<PyObjectRef>,
}
// Argument bundle (parsed via FromArgs) with a single optional flag.
// NOTE(review): presumably selects a binary (DER) result over a dict for
// certificate getters - the consumers are not visible in this part of the file.
#[derive(FromArgs)]
struct GetCertArgs {
#[pyarg(any, optional)]
binary_form: OptionalArg<bool>,
}
#[pyclass(with(Constructor), flags(BASETYPE))]
impl PySSLContext {
// Helper method to convert DER certificate bytes to Python dict.
// Thin wrapper: forwards to cert::cert_der_to_dict_helper and returns its
// result unchanged (it does not read any state from `self`).
fn cert_der_to_dict(&self, vm: &VirtualMachine, cert_der: &[u8]) -> PyResult<PyObjectRef> {
cert::cert_der_to_dict_helper(vm, cert_der)
}
#[pymethod]
fn __repr__(&self) -> String {
    // Rendered as "<SSLContext(protocol=N)>" with N being the numeric selector.
    let protocol = self.protocol;
    format!("<SSLContext(protocol={protocol})>")
}
// Getter for `check_hostname`: returns the currently stored flag.
#[pygetset]
fn check_hostname(&self) -> bool {
*self.check_hostname.read()
}
#[pygetset(setter)]
fn set_check_hostname(&self, value: bool) {
    // Store the flag first, then fix up verify_mode if necessary.
    *self.check_hostname.write() = value;
    // Enabling hostname checking makes no sense without certificate
    // verification, so CERT_NONE is upgraded to CERT_REQUIRED. Other modes
    // are left untouched. (The read guard used in the condition is released
    // before the write lock is taken.)
    if value && *self.verify_mode.read() == CERT_NONE {
        *self.verify_mode.write() = CERT_REQUIRED;
    }
}
// Getter for `verify_mode`: returns the stored CERT_* value.
#[pygetset]
fn verify_mode(&self) -> i32 {
*self.verify_mode.read()
}
#[pygetset(setter)]
fn set_verify_mode(&self, mode: i32, vm: &VirtualMachine) -> PyResult<()> {
if !(CERT_NONE..=CERT_REQUIRED).contains(&mode) {
return Err(vm.new_value_error("invalid verify mode"));
}
// Cannot set CERT_NONE when check_hostname is enabled
if mode == CERT_NONE && *self.check_hostname.read() {
return Err(vm.new_value_error(
"Cannot set verify_mode to CERT_NONE when check_hostname is enabled",
));
}
*self.verify_mode.write() = mode;
Ok(())
}
// Getter for `protocol`: returns the context's PROTOCOL_* selector (no setter is defined here).
#[pygetset]
fn protocol(&self) -> i32 {
self.protocol
}
// Getter for `verify_flags`: returns the stored VERIFY_* bitmask.
#[pygetset]
fn verify_flags(&self) -> i32 {
*self.verify_flags.read()
}
// Setter for `verify_flags`: stores the value as given, without validation.
#[pygetset(setter)]
fn set_verify_flags(&self, value: i32) {
*self.verify_flags.write() = value;
}
// Getter for `post_handshake_auth`: returns the stored flag.
#[pygetset]
fn post_handshake_auth(&self) -> bool {
*self.post_handshake_auth.read()
}
// Setter for `post_handshake_auth`: stores the flag; nothing else is updated here.
#[pygetset(setter)]
fn set_post_handshake_auth(&self, value: bool) {
*self.post_handshake_auth.write() = value;
}
// Getter for `num_tickets`: returns the stored ticket count.
#[pygetset]
fn num_tickets(&self) -> i32 {
*self.num_tickets.read()
}
#[pygetset(setter)]
fn set_num_tickets(&self, value: i32, vm: &VirtualMachine) -> PyResult<()> {
    // Checks run in this order: sign first, then context kind.
    let rejection = if value < 0 {
        Some("num_tickets must be a non-negative integer")
    } else if self.protocol != PROTOCOL_TLS_SERVER {
        Some("num_tickets can only be set on server-side contexts")
    } else {
        None
    };
    match rejection {
        Some(message) => Err(vm.new_value_error(message)),
        None => {
            *self.num_tickets.write() = value;
            Ok(())
        }
    }
}
// Getter for `options`: returns the stored OP_* bitmask.
#[pygetset]
fn options(&self) -> i32 {
*self.options.read()
}
#[pygetset(setter)]
fn set_options(&self, value: i32, vm: &VirtualMachine) -> PyResult<()> {
    // Deprecated SSL/TLS protocol version options
    const DEPRECATED_VERSION_OPTS: i32 = OP_NO_SSLv2
        | OP_NO_SSLv3
        | OP_NO_TLSv1
        | OP_NO_TLSv1_1
        | OP_NO_TLSv1_2
        | OP_NO_TLSv1_3;

    // Negative values are rejected with OverflowError.
    if value < 0 {
        return Err(vm.new_overflow_error("options must be non-negative".to_owned()));
    }

    // Bits present in the new value that were not set before.
    let previous = *self.options.read();
    let newly_set = value & !previous;

    // Warn only when a deprecated bit is being turned on, not when it merely
    // stays on. A warning escalated to an exception aborts the assignment.
    if (newly_set & DEPRECATED_VERSION_OPTS) != 0 {
        warnings::warn(
            vm.ctx.exceptions.deprecation_warning,
            "ssl.OP_NO_SSL*/ssl.OP_NO_TLS* options are deprecated".to_owned(),
            2, // stack_level = 2
            vm,
        )?;
    }

    *self.options.write() = value;
    Ok(())
}
#[pygetset]
fn minimum_version(&self) -> i32 {
    // A stored 0 means "not pinned" and is reported as MINIMUM_SUPPORTED.
    match *self.minimum_version.read() {
        0 => PROTO_MINIMUM_SUPPORTED,
        pinned => pinned,
    }
}
#[pygetset(setter)]
fn set_minimum_version(&self, value: i32, vm: &VirtualMachine) -> PyResult<()> {
    // Accepted inputs: 0 (default), the two sentinels, or 0x0300-0x0304
    // (SSLv3-TLSv1.3). Anything else is a ValueError.
    let stored = match value {
        // MINIMUM_SUPPORTED (-2) -> 0 (auto-negotiate)
        PROTO_MINIMUM_SUPPORTED => 0,
        // MAXIMUM_SUPPORTED (-1) -> MAXIMUM_VERSION (TLSv1.3)
        PROTO_MAXIMUM_SUPPORTED => MAXIMUM_VERSION,
        // Explicit versions and the default are stored as given
        0 | PROTO_SSLv3..=PROTO_TLSv1_3 => value,
        _ => {
            return Err(vm.new_value_error(format!("invalid protocol version: {value}")));
        }
    };
    *self.minimum_version.write() = stored;
    Ok(())
}
#[pygetset]
fn maximum_version(&self) -> i32 {
    // A stored 0 means "not pinned" and is reported as MAXIMUM_SUPPORTED.
    match *self.maximum_version.read() {
        0 => PROTO_MAXIMUM_SUPPORTED,
        pinned => pinned,
    }
}
#[pygetset(setter)]
fn set_maximum_version(&self, value: i32, vm: &VirtualMachine) -> PyResult<()> {
    // Accepted inputs: 0 (default), the two sentinels, or 0x0300-0x0304
    // (SSLv3-TLSv1.3). Anything else is a ValueError.
    let stored = match value {
        // MAXIMUM_SUPPORTED (-1) -> 0 (auto-negotiate)
        PROTO_MAXIMUM_SUPPORTED => 0,
        // MINIMUM_SUPPORTED (-2) -> MINIMUM_VERSION (TLSv1.2)
        PROTO_MINIMUM_SUPPORTED => MINIMUM_VERSION,
        // Explicit versions and the default are stored as given
        0 | PROTO_SSLv3..=PROTO_TLSv1_3 => value,
        _ => {
            return Err(vm.new_value_error(format!("invalid protocol version: {value}")));
        }
    };
    *self.maximum_version.write() = stored;
    Ok(())
}
#[pymethod]
fn load_cert_chain(&self, args: LoadCertChainArgs, vm: &VirtualMachine) -> PyResult<()> {
// Parse certfile argument (str or bytes) to path
let cert_path = Self::parse_path_arg(&args.certfile, vm)?;
// Parse keyfile argument (default to certfile if not provided)
let key_path = match args.keyfile {
OptionalArg::Present(Some(ref k)) => Self::parse_path_arg(k, vm)?,
_ => cert_path.clone(),
};
// Parse password argument (str, bytes-like, or callable)
// Callable passwords are NOT invoked immediately (lazy evaluation)
let (password_str, password_callable) =
Self::parse_password_argument(&args.password, vm)?;
// Validate immediate password length (limit: PEM_BUFSIZE = 1024 bytes)
if let Some(ref pwd) = password_str
&& pwd.len() > PEM_BUFSIZE
{
return Err(vm.new_value_error(format!(
"password cannot be longer than {PEM_BUFSIZE} bytes",
)));
}
// First attempt: Load with immediate password (or None if callable)
let mut result =
cert::load_cert_chain_from_file(&cert_path, &key_path, password_str.as_deref());
// If failed and callable exists, invoke it and retry
// This implements lazy evaluation: callable only invoked if password is actually needed
if result.is_err()
&& let Some(callable) = password_callable
{
// Invoke callable - exceptions propagate naturally
let pwd_result = callable.call((), vm)?;
// Convert callable result to string
let password_from_callable = if let Ok(pwd_str) =
PyStrRef::try_from_object(vm, pwd_result.clone())
{
pwd_str.as_str().to_owned()
} else if let Ok(pwd_bytes_like) = ArgBytesLike::try_from_object(vm, pwd_result) {
String::from_utf8(pwd_bytes_like.borrow_buf().to_vec()).map_err(|_| {
vm.new_type_error(
"password callback returned invalid UTF-8 bytes".to_owned(),
)
})?
} else {
return Err(vm.new_type_error(
"password callback must return a string or bytes".to_owned(),
));
};
// Validate callable password length
if password_from_callable.len() > PEM_BUFSIZE {
return Err(vm.new_value_error(format!(
"password cannot be longer than {PEM_BUFSIZE} bytes",
)));
}
// Retry with callable password
result = cert::load_cert_chain_from_file(
&cert_path,
&key_path,
Some(&password_from_callable),
);
}
// Process result
let (certs, key) = result.map_err(|e| {
// Try to downcast to io::Error to preserve errno information
if let Ok(io_err) = e.downcast::<std::io::Error>() {
match io_err.kind() {
// File access errors (NotFound, PermissionDenied) - preserve errno
std::io::ErrorKind::NotFound | std::io::ErrorKind::PermissionDenied => {
io_err.into_pyexception(vm)
}
// Other io::Error types
std::io::ErrorKind::Other => {
let msg = io_err.to_string();
if msg.contains("Failed to decrypt") || msg.contains("wrong password") {
// Wrong password error
vm.new_os_subtype_error(
PySSLError::class(&vm.ctx).to_owned(),
None,
msg,
)
.upcast()
} else {
// [SSL] PEM lib
super::compat::SslError::create_ssl_error_with_reason(
vm,
Some("SSL"),
"",
"PEM lib",
)
}
}
// PEM parsing errors - [SSL] PEM lib
_ => super::compat::SslError::create_ssl_error_with_reason(
vm,
Some("SSL"),
"",
"PEM lib",
),
}
} else {
// Unknown error type - [SSL] PEM lib
super::compat::SslError::create_ssl_error_with_reason(
vm,
Some("SSL"),
"",
"PEM lib",
)
}
})?;
// Validate certificate and key match
cert::validate_cert_key_match(&certs, &key).map_err(|e| {
let msg = if e.contains("key values mismatch") {
"[SSL: KEY_VALUES_MISMATCH] key values mismatch".to_owned()
} else {
e
};
vm.new_os_subtype_error(PySSLError::class(&vm.ctx).to_owned(), Some(0), msg)
.upcast()
})?;
// Auto-build certificate chain: if only leaf cert is in file, try to add CA certs
// This matches OpenSSL behavior where it automatically includes intermediate/CA certs
let mut full_chain = certs.clone();
if full_chain.len() == 1 {
// Only have leaf cert, try to build chain from CA certs
let ca_certs_der = self.ca_certs_der.read();
if !ca_certs_der.is_empty() {
// Use build_verified_chain to construct full chain
let chain_result = cert::build_verified_chain(&full_chain, &ca_certs_der);
if chain_result.len() > 1 {
// Successfully built a longer chain
full_chain = chain_result.into_iter().map(CertificateDer::from).collect();
}
}
}
// Additional validation: Create CertifiedKey to ensure rustls accepts it
let signing_key =
rustls::crypto::aws_lc_rs::sign::any_supported_type(&key).map_err(|_| {
vm.new_os_subtype_error(
PySSLError::class(&vm.ctx).to_owned(),
None,
"[SSL: KEY_VALUES_MISMATCH] key values mismatch",
)
.upcast()
})?;
let certified_key = CertifiedKey::new(full_chain.clone(), signing_key);
if certified_key.keys_match().is_err() {
return Err(vm
.new_os_subtype_error(
PySSLError::class(&vm.ctx).to_owned(),
None,
"[SSL: KEY_VALUES_MISMATCH] key values mismatch",
)
.upcast());
}
// Add cert/key pair to collection (OpenSSL allows multiple cert/key pairs)
// Store both CertifiedKey (for server) and PrivateKeyDer (for client mTLS)
let cert_der = &full_chain[0];
let mut cert_keys = self.cert_keys.write();
// Remove any existing cert/key pair with the same certificate
// (This allows updating cert/key pair without duplicating)
cert_keys.retain(|(existing, _)| &existing.cert[0] != cert_der);
// Add new cert/key pair as tuple
cert_keys.push((Arc::new(certified_key), key));
Ok(())
}
// Load CA certificates (and/or a CRL) from `cafile`, `capath` and `cadata`.
//
// At least one of the three arguments must be given, otherwise a TypeError is
// raised. A `cafile` that parses as a CRL is stored in `self.crls` and is not
// loaded as a certificate file.
#[pymethod]
fn load_verify_locations(
    &self,
    args: LoadVerifyLocationsArgs,
    vm: &VirtualMachine,
) -> PyResult<()> {
    // Resolve every supplied argument before any lock is taken. When nothing
    // was supplied, none of the parsers run and all three values are None.
    let cafile_path = match &args.cafile {
        OptionalArg::Present(Some(cafile_obj)) => Some(Self::parse_path_arg(cafile_obj, vm)?),
        _ => None,
    };
    let capath_dir = match &args.capath {
        OptionalArg::Present(Some(capath_obj)) => Some(Self::parse_path_arg(capath_obj, vm)?),
        _ => None,
    };
    let cadata_parsed = match &args.cadata {
        OptionalArg::Present(Some(cadata_obj)) => {
            // str input is restricted to PEM further down
            let is_string = matches!(cadata_obj, Either::A(_));
            Some((self.parse_cadata_arg(cadata_obj, vm)?, is_string))
        }
        _ => None,
    };
    if cafile_path.is_none() && capath_dir.is_none() && cadata_parsed.is_none() {
        return Err(
            vm.new_type_error("cafile, capath and cadata cannot be all omitted".to_owned())
        );
    }

    // Probe cafile for a CRL first; the CRL list has its own lock.
    let crl_opt = match cafile_path.as_deref() {
        Some(path) => self.load_crl_from_file(path, vm)?,
        None => None,
    };
    let cafile_is_crl = crl_opt.is_some();
    if let Some(crl) = crl_opt {
        self.crls.write().push(crl);
    }

    // Certificate loading needs both the trust store and the DER list.
    let mut root_store = self.root_certs.write();
    let mut ca_certs_der = self.ca_certs_der.write();

    // cafile: only when it was not consumed as a CRL above
    match cafile_path.as_deref() {
        Some(path) if !cafile_is_crl => {
            let stats =
                self.load_certs_from_file_helper(&mut root_store, &mut ca_certs_der, path, vm)?;
            self.update_cert_stats(stats);
        }
        _ => {}
    }

    // capath: goes into the trust store only, not into ca_certs_der
    if let Some(dir_path) = capath_dir.as_deref() {
        let stats = self.load_certs_from_dir_helper(&mut root_store, dir_path, vm)?;
        self.update_cert_stats(stats);
    }

    // cadata: bytes or str
    if let Some((data_vec, is_string)) = &cadata_parsed {
        let stats = self.load_certs_from_bytes_helper(
            &mut root_store,
            &mut ca_certs_der,
            data_vec,
            *is_string, // PEM only for strings
            vm,
        )?;
        self.update_cert_stats(stats);
    }
    Ok(())
}
/// Helper: look up `var_name` in the given `os.environ` mapping and convert
/// the value to a Rust `String`.
///
/// Fails if the key is missing or the value cannot be converted.
fn get_env_path(
    environ: &PyObject,
    var_name: &str,
    vm: &VirtualMachine,
) -> PyResult<String> {
    environ.get_item(var_name, vm)?.try_into_value(vm)
}
/// Helper: try to load certificates named by Python's `os.environ`.
///
/// `SSL_CERT_FILE` is tried first; `SSL_CERT_DIR` is consulted only when the
/// file variable did not yield certificates. Returns `Ok(true)` as soon as one
/// source loads successfully and `Ok(false)` otherwise. Lookup, existence and
/// load failures are ignored; only a failure to import `os` or to read
/// `os.environ` is propagated.
///
/// Python's `os.environ` is used rather than Rust's `std::env` because Python
/// code can modify `os.environ` at runtime (e.g.
/// `os.environ['SSL_CERT_FILE'] = '/path'`), but rustls-native-certs uses
/// std::env which only sees the process environment at startup.
fn try_load_from_python_environ(
    &self,
    loader: &mut cert::CertLoader<'_>,
    vm: &VirtualMachine,
) -> PyResult<bool> {
    use std::path::Path;
    let environ = vm.import("os", 0)?.get_attr("environ", vm)?;

    // SSL_CERT_FILE: must name an existing path that loads cleanly
    let from_file = Self::get_env_path(&environ, "SSL_CERT_FILE", vm)
        .ok()
        .filter(|cert_file| Path::new(cert_file).exists())
        .and_then(|cert_file| loader.load_from_file(&cert_file).ok());
    if let Some(stats) = from_file {
        self.update_cert_stats(stats);
        return Ok(true);
    }

    // SSL_CERT_DIR: must name an existing directory that loads cleanly
    let from_dir = Self::get_env_path(&environ, "SSL_CERT_DIR", vm)
        .ok()
        .filter(|cert_dir| Path::new(cert_dir).is_dir())
        .and_then(|cert_dir| loader.load_from_dir(&cert_dir).ok());
    if let Some(stats) = from_dir {
        self.update_cert_stats(stats);
        return Ok(true);
    }

    Ok(false)
}
/// Helper: Load system certificates using rustls-native-certs
///
/// This uses platform-specific methods:
/// - Linux: openssl-probe to find certificate files
/// - macOS: Keychain API
/// - Windows: System certificate store (ROOT + CA stores)
fn load_system_certificates(
&self,
store: &mut rustls::RootCertStore,
vm: &VirtualMachine,
) -> PyResult<()> {
#[cfg(windows)]
{
// Windows: Use schannel to load from both ROOT and CA stores
use schannel::cert_store::CertStore;
let store_names = ["ROOT", "CA"];
let open_fns = [CertStore::open_current_user, CertStore::open_local_machine];
for store_name in store_names {
for open_fn in &open_fns {
if let Ok(cert_store) = open_fn(store_name) {
for cert_ctx in cert_store.certs() {
let der_bytes = cert_ctx.to_der();
let cert =
rustls::pki_types::CertificateDer::from(der_bytes.to_vec());
let is_ca = cert::is_ca_certificate(cert.as_ref());
if store.add(cert).is_ok() {
*self.x509_cert_count.write() += 1;
if is_ca {
*self.ca_cert_count.write() += 1;
}
}
}
}
}
}
if *self.x509_cert_count.read() == 0 {
return Err(vm.new_os_error("Failed to load certificates from Windows store"));
}
Ok(())
}
#[cfg(not(windows))]
{
let result = rustls_native_certs::load_native_certs();
// Load successfully found certificates
for cert in result.certs {
let is_ca = cert::is_ca_certificate(cert.as_ref());
if store.add(cert).is_ok() {
*self.x509_cert_count.write() += 1;
if is_ca {
*self.ca_cert_count.write() += 1;
}
}
}
// If there were errors but some certs loaded, just continue
// If NO certs loaded and there were errors, report the first error
if *self.x509_cert_count.read() == 0 && !result.errors.is_empty() {
return Err(vm.new_os_error(format!(
"Failed to load native certificates: {}",
result.errors[0]
)));
}
Ok(())
}
}
#[pymethod]
fn load_default_certs(
&self,
_purpose: OptionalArg<i32>,
vm: &VirtualMachine,
) -> PyResult<()> {
let mut store = self.root_certs.write();
#[cfg(windows)]
{
// Windows: Load system certificates first, then additionally load from env
// see: test_load_default_certs_env_windows
let _ = self.load_system_certificates(&mut store, vm);
let mut lazy_ca_certs = Vec::new();
let mut loader = cert::CertLoader::new(&mut store, &mut lazy_ca_certs);
let _ = self.try_load_from_python_environ(&mut loader, vm)?;
}
#[cfg(not(windows))]
{
// Non-Windows: Try env vars first; only fallback to system certs if not set
// see: test_load_default_certs_env
let mut lazy_ca_certs = Vec::new();
let mut loader = cert::CertLoader::new(&mut store, &mut lazy_ca_certs);
let loaded = self.try_load_from_python_environ(&mut loader, vm)?;
if !loaded {
let _ = self.load_system_certificates(&mut store, vm);
}
}
// If no certificates were loaded from system, fallback to webpki-roots (Mozilla CA bundle)
// This ensures we always have some trusted root certificates even if system cert loading fails
if *self.x509_cert_count.read() == 0 {
use webpki_roots;
// webpki_roots provides TLS_SERVER_ROOTS as &[TrustAnchor]
// We can use extend() to add them to the RootCertStore
let webpki_count = webpki_roots::TLS_SERVER_ROOTS.len();
store.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
*self.x509_cert_count.write() += webpki_count;
*self.ca_cert_count.write() += webpki_count;
}
Ok(())
}
// Replace the ALPN protocol list with the bytes-like items of `protocols`.
// Fails with the conversion error of the first item that is not bytes-like;
// in that case the stored list is left untouched.
#[pymethod]
fn set_alpn_protocols(&self, protocols: PyListRef, vm: &VirtualMachine) -> PyResult<()> {
    let alpn_list = protocols
        .borrow_vec()
        .iter()
        .map(|item| -> PyResult<Vec<u8>> {
            let buffer = ArgBytesLike::try_from_object(vm, item.clone())?;
            let protocol = buffer.borrow_buf().to_vec();
            Ok(protocol)
        })
        .collect::<PyResult<Vec<_>>>()?;
    *self.alpn_protocols.write() = alpn_list;
    Ok(())
}
// Replace the ALPN protocol list from a single buffer that is decoded by
// `parse_length_prefixed_alpn`. A decoding error leaves the stored list
// untouched.
#[pymethod]
fn _set_alpn_protocols(&self, protos: ArgBytesLike, vm: &VirtualMachine) -> PyResult<()> {
    let wire = protos.borrow_buf();
    let decoded = parse_length_prefixed_alpn(&wire, vm)?;
    let mut slot = self.alpn_protocols.write();
    *slot = decoded;
    Ok(())
}
// Parse an OpenSSL-style cipher string and remember the resulting selection
// on the context. A string that `parse_cipher_string` rejects is reported as
// SSLError carrying the parser's message.
#[pymethod]
fn set_ciphers(&self, ciphers: PyStrRef, vm: &VirtualMachine) -> PyResult<()> {
    match parse_cipher_string(ciphers.as_str()) {
        Ok(selection) => {
            *self.selected_ciphers.write() = Some(selection);
            Ok(())
        }
        Err(reason) => Err(vm
            .new_os_subtype_error(PySSLError::class(&vm.ctx).to_owned(), None, reason)
            .upcast()),
    }
}
#[pymethod]
fn get_ciphers(&self, vm: &VirtualMachine) -> PyResult<PyListRef> {
// Dynamically generate cipher list from rustls ALL_CIPHER_SUITES
// This automatically includes all cipher suites supported by the current rustls version
use rustls::crypto::aws_lc_rs::ALL_CIPHER_SUITES;
let cipher_list = ALL_CIPHER_SUITES
.iter()
.map(|suite| {
// Extract cipher information using unified helper
let cipher_info = extract_cipher_info(suite);
// Convert to OpenSSL-style name
// e.g., "TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256" -> "ECDHE-RSA-AES128-GCM-SHA256"
let openssl_name = normalize_cipher_name(&cipher_info.name);
// Determine key exchange and auth methods
let (kx, auth) = if cipher_info.protocol == "TLSv1.3" {
// TLS 1.3 doesn't distinguish - all use modern algos
("any", "any")
} else if cipher_info.name.contains("ECDHE") {
// TLS 1.2 with ECDHE
let auth = if cipher_info.name.contains("ECDSA") {
"ECDSA"
} else if cipher_info.name.contains("RSA") {
"RSA"
} else {
"any"
};
("ECDH", auth)
} else {
("any", "any")
};
// Build description string
// Format: "{name} {protocol} Kx={kx} Au={auth} Enc={enc} Mac={mac}"
let enc = get_cipher_encryption_desc(&openssl_name);
let description = format!(
"{} {} Kx={} Au={} Enc={} Mac=AEAD",
openssl_name, cipher_info.protocol, kx, auth, enc
);
// Create cipher dict
let dict = vm.ctx.new_dict();
dict.set_item("name", vm.ctx.new_str(openssl_name).into(), vm)
.unwrap();
dict.set_item("protocol", vm.ctx.new_str(cipher_info.protocol).into(), vm)
.unwrap();
dict.set_item("id", vm.ctx.new_int(0).into(), vm).unwrap(); // Placeholder ID
dict.set_item("strength_bits", vm.ctx.new_int(cipher_info.bits).into(), vm)
.unwrap();
dict.set_item("alg_bits", vm.ctx.new_int(cipher_info.bits).into(), vm)
.unwrap();
dict.set_item("description", vm.ctx.new_str(description).into(), vm)
.unwrap();
dict.into()
})
.collect::<Vec<_>>();
Ok(PyListRef::from(vm.ctx.new_list(cipher_list)))
}
// Delegates to `load_default_certs`, called without a purpose argument.
#[pymethod]
fn set_default_verify_paths(&self, vm: &VirtualMachine) -> PyResult<()> {
    let no_purpose = OptionalArg::Missing;
    self.load_default_certs(no_purpose, vm)
}
// Return a dict with the number of loaded certificates ("x509"), loaded
// certificate revocation lists ("crl") and CA certificates ("x509_ca").
//
// The certificate figures come from the counters maintained while loading
// certificates; the CRL figure is the number of entries currently held in
// `self.crls` (filled by `load_verify_locations` when `cafile` is a CRL).
#[pymethod]
fn cert_store_stats(&self, vm: &VirtualMachine) -> PyResult<PyObjectRef> {
    // Counters are passed to `new_int` at full width so they cannot be
    // truncated by a narrowing cast.
    let x509_count = *self.x509_cert_count.read();
    let ca_count = *self.ca_cert_count.read();
    let crl_count = self.crls.read().len();
    let dict = vm.ctx.new_dict();
    dict.set_item("x509", vm.ctx.new_int(x509_count).into(), vm)?;
    dict.set_item("crl", vm.ctx.new_int(crl_count).into(), vm)?;
    dict.set_item("x509_ca", vm.ctx.new_int(ca_count).into(), vm)?;
    Ok(dict.into())
}
// Return session statistics as a dict.
//
// NOTE: partial implementation. Only "number" (size of the client session
// cache), "accept"/"accept_good" (both the accept counter, i.e. every accept
// is assumed good) and "hits" carry real values; all other keys are reported
// as 0 because they are not tracked here.
#[pymethod]
fn session_stats(&self, vm: &VirtualMachine) -> PyResult<PyObjectRef> {
    let session_count = self.client_session_cache.read().len() as i32;
    let accept_count = self.accept_count.load(Ordering::SeqCst) as i32;
    let hits = self.session_hits.load(Ordering::SeqCst) as i32;

    // Keys in the order they are inserted into the dict
    let entries: [(&str, i32); 11] = [
        ("number", session_count),
        // Client-side statistics (not tracked separately)
        ("connect", 0),
        ("connect_good", 0),
        ("connect_renegotiate", 0),
        // Server-side statistics
        ("accept", accept_count),
        ("accept_good", accept_count),
        ("accept_renegotiate", 0),
        // Session reuse statistics
        ("hits", hits),
        ("misses", 0),
        ("timeouts", 0),
        ("cache_full", 0),
    ];

    let dict = vm.ctx.new_dict();
    for (key, value) in entries {
        dict.set_item(key, vm.ctx.new_int(value).into(), vm)?;
    }
    Ok(dict.into())
}
// Getter: the currently registered SNI callback, if any.
#[pygetset]
fn sni_callback(&self) -> Option<PyObjectRef> {
    let current = self.sni_callback.read();
    current.as_ref().cloned()
}
#[pygetset(setter)]
fn set_sni_callback(
&self,
callback: Option<PyObjectRef>,
vm: &VirtualMachine,
) -> PyResult<()> {
// Validate callback is callable or None
if let Some(ref cb) = callback
&& !cb.is(vm.ctx.types.none_type)
&& !cb.is_callable()
{
return Err(vm.new_type_error("sni_callback must be callable or None"));
}
*self.sni_callback.write() = callback;
Ok(())
}
#[pymethod]
fn set_servername_callback(
&self,
callback: Option<PyObjectRef>,
vm: &VirtualMachine,
) -> PyResult<()> {
// Alias for set_sni_callback
self.set_sni_callback(callback, vm)
}
// Getter: always reports security level 2. The value is a fixed constant and
// is not derived from the context's configuration.
#[pygetset]
fn security_level(&self) -> i32 {
    const FIXED_SECURITY_LEVEL: i32 = 2;
    FIXED_SECURITY_LEVEL
}
// Getter: the currently registered message callback, if any.
#[pygetset]
fn _msg_callback(&self) -> Option<PyObjectRef> {
    let current = self.msg_callback.read();
    current.as_ref().cloned()
}
#[pygetset(setter)]
fn set__msg_callback(
&self,
callback: Option<PyObjectRef>,
vm: &VirtualMachine,
) -> PyResult<()> {
// Validate callback is callable or None
if let Some(ref cb) = callback
&& !cb.is(vm.ctx.types.none_type)
&& !cb.is_callable()
{
return Err(vm.new_type_error("msg_callback must be callable or None"));
}
*self.msg_callback.write() = callback;
Ok(())
}
#[pymethod]
fn get_ca_certs(&self, args: GetCertArgs, vm: &VirtualMachine) -> PyResult<PyListRef> {
let binary_form = args.binary_form.unwrap_or(false);
let ca_certs_der = self.ca_certs_der.read();
let mut certs = Vec::new();
for cert_der in ca_certs_der.iter() {
// Parse certificate to check if it's a CA and get info
match x509_parser::parse_x509_certificate(cert_der) {
Ok((_, cert)) => {
// Check if this is a CA certificate (BasicConstraints: CA=TRUE)
let is_ca = if let Ok(Some(bc_ext)) = cert.basic_constraints() {
bc_ext.value.ca
} else {
false
};
// Only include CA certificates
if !is_ca {
continue;
}
if binary_form {
// Return DER-encoded certificate as bytes
certs.push(vm.ctx.new_bytes(cert_der.clone()).into());
} else {
// Return certificate as dict (use helper from _test_decode_cert)
let dict = self.cert_der_to_dict(vm, cert_der)?;
certs.push(dict);
}
}
Err(_) => {
// Skip invalid certificates
continue;
}
}
}
Ok(PyListRef::from(vm.ctx.new_list(certs)))
}
// Validate a DH parameter file without actually using it.
//
// `filepath` must be a str or bytes path (None is rejected with TypeError).
// A missing file raises FileNotFoundError (errno 2, `filename` set); a file
// without a "DH PARAMETERS" PEM header raises SSLError. On success nothing is
// stored: the call only exists for compatibility with OpenSSL-based code.
#[pymethod]
fn load_dh_params(&self, filepath: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> {
    if vm.is_none(&filepath) {
        return Err(vm.new_type_error("DH params filepath cannot be None".to_owned()));
    }

    // Accept str first, then any bytes-like object holding UTF-8
    let path_str = match PyStrRef::try_from_object(vm, filepath.clone()) {
        Ok(text) => text.as_str().to_owned(),
        Err(_) => match ArgBytesLike::try_from_object(vm, filepath) {
            Ok(buffer) => {
                let raw = buffer.borrow_buf().to_vec();
                String::from_utf8(raw)
                    .map_err(|_| vm.new_value_error("Invalid path encoding".to_owned()))?
            }
            Err(_) => {
                return Err(
                    vm.new_type_error("DH params filepath must be str or bytes".to_owned())
                );
            }
        },
    };

    if !std::path::Path::new(&path_str).exists() {
        // FileNotFoundError with errno = ENOENT (2) and the filename attached
        let exc = vm.new_os_subtype_error(
            vm.ctx.exceptions.file_not_found_error.to_owned(),
            Some(2),
            "No such file or directory",
        );
        // Failing to attach the filename is not treated as an error
        let _ = exc
            .as_object()
            .set_attr("filename", vm.ctx.new_str(path_str), vm);
        return Err(exc.upcast());
    }

    // The file must carry one of the DH parameter PEM headers
    let contents =
        std::fs::read_to_string(&path_str).map_err(|e| vm.new_os_error(e.to_string()))?;
    let has_dh_header = ["BEGIN DH PARAMETERS", "BEGIN X9.42 DH PARAMETERS"]
        .iter()
        .any(|marker| contents.contains(*marker));
    if !has_dh_header {
        // [PEM: NO_START_LINE] no start line
        return Err(super::compat::SslError::create_ssl_error_with_reason(
            vm,
            Some("PEM"),
            "NO_START_LINE",
            "[PEM: NO_START_LINE] no start line",
        ));
    }

    // Nothing to configure: the parameters themselves are ignored
    Ok(())
}
// Remember the ECDH curve name requested for this context.
//
// `name` must be a str or bytes object (None is rejected with TypeError) and
// must be one of the names in `KNOWN_CURVES`, otherwise ValueError is raised.
// The accepted name is stored verbatim in `self.ecdh_curve`.
#[pymethod]
fn set_ecdh_curve(&self, name: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> {
    // Names accepted for compatibility; matching is case-sensitive
    const KNOWN_CURVES: [&str; 9] = [
        "prime256v1",
        "secp256r1",
        "prime384v1",
        "secp384r1",
        "prime521v1",
        "secp521r1",
        "X25519",
        "x25519",
        "x448", // For future compatibility
    ];

    if vm.is_none(&name) {
        return Err(vm.new_type_error("ECDH curve name cannot be None".to_owned()));
    }

    // Accept str first, then any bytes-like object holding UTF-8
    let curve_name = match PyStrRef::try_from_object(vm, name.clone()) {
        Ok(text) => text.as_str().to_owned(),
        Err(_) => match ArgBytesLike::try_from_object(vm, name) {
            Ok(buffer) => {
                let raw = buffer.borrow_buf().to_vec();
                String::from_utf8(raw)
                    .map_err(|_| vm.new_value_error("Invalid curve name encoding".to_owned()))?
            }
            Err(_) => {
                return Err(
                    vm.new_type_error("ECDH curve name must be str or bytes".to_owned())
                );
            }
        },
    };

    let is_known = KNOWN_CURVES
        .iter()
        .any(|&known| known == curve_name.as_str());
    if !is_known {
        return Err(vm.new_value_error(format!("unknown curve name '{curve_name}'")));
    }

    *self.ecdh_curve.write() = Some(curve_name);
    Ok(())
}
// Create an `_SSLSocket` bound to a real socket.
//
// `server_hostname` (missing and None are equivalent) must not be empty, must
// not start with a dot and must not contain NUL. The requested side must be
// compatible with the context protocol, otherwise SSLError is raised. The TLS
// connection itself is not created here (`connection` starts as None).
#[pymethod]
fn _wrap_socket(
    zelf: PyRef<Self>,
    args: WrapSocketArgs,
    vm: &VirtualMachine,
) -> PyResult<PyRef<PySSLSocket>> {
    // Validate and copy the hostname, if one was supplied
    let hostname = args
        .server_hostname
        .into_option()
        .flatten()
        .map(|requested| -> PyResult<String> {
            let name = requested.as_str();
            if name.is_empty() {
                return Err(vm.new_value_error("server_hostname cannot be an empty string"));
            }
            if name.starts_with('.') {
                return Err(vm.new_value_error("server_hostname cannot start with a dot"));
            }
            // IP addresses are allowed (SNI will not be sent for them)
            if name.contains('\0') {
                return Err(vm.new_type_error("embedded null character"));
            }
            Ok(name.to_string())
        })
        .transpose()?;

    // A client-only context cannot serve, a server-only context cannot connect
    let side_conflict = match zelf.protocol {
        PROTOCOL_TLS_CLIENT if args.server_side => {
            Some("Cannot create a server socket with a PROTOCOL_TLS_CLIENT context")
        }
        PROTOCOL_TLS_SERVER if !args.server_side => {
            Some("Cannot create a client socket with a PROTOCOL_TLS_SERVER context")
        }
        _ => None,
    };
    if let Some(message) = side_conflict {
        return Err(vm
            .new_os_subtype_error(PySSLError::class(&vm.ctx).to_owned(), None, message)
            .upcast());
    }

    let ssl_socket = PySSLSocket {
        sock: args.sock.clone(),
        context: PyRwLock::new(zelf),
        server_side: args.server_side,
        server_hostname: PyRwLock::new(hostname),
        connection: PyMutex::new(None),
        handshake_done: PyMutex::new(false),
        session_was_reused: PyMutex::new(false),
        owner: PyRwLock::new(args.owner.into_option()),
        // A Python None session counts as "no session"
        session: PyRwLock::new(args.session.into_option().filter(|s| !vm.is_none(s))),
        verified_chain: PyRwLock::new(None),
        // Socket mode: no memory BIOs
        incoming_bio: None,
        outgoing_bio: None,
        sni_state: PyRwLock::new(None),
        pending_context: PyRwLock::new(None),
        client_hello_buffer: PyMutex::new(None),
        shutdown_state: PyMutex::new(ShutdownState::NotStarted),
        pending_tls_output: PyMutex::new(Vec::new()),
        write_buffered_len: PyMutex::new(0),
        deferred_cert_error: Arc::new(ParkingRwLock::new(None)),
    };

    // Instantiate with the Python-visible `_ssl._SSLSocket` type
    ssl_socket
        .into_ref_with_type(vm, vm.class("_ssl", "_SSLSocket"))
        .map_err(|_| vm.new_type_error("Failed to create SSLSocket"))
}
// Create an `_SSLSocket` that runs over a pair of memory BIOs instead of a
// real socket (`sock` is set to None).
//
// `server_hostname` (missing and None are equivalent) is checked with
// `validate_hostname`; `server_side` defaults to false. The requested side
// must be compatible with the context protocol, otherwise SSLError is raised.
#[pymethod]
fn _wrap_bio(
    zelf: PyRef<Self>,
    args: WrapBioArgs,
    vm: &VirtualMachine,
) -> PyResult<PyRef<PySSLSocket>> {
    // Validate and copy the hostname, if one was supplied
    let hostname = args
        .server_hostname
        .into_option()
        .flatten()
        .map(|requested| -> PyResult<String> {
            let name = requested.as_str();
            validate_hostname(name, vm)?;
            Ok(name.to_string())
        })
        .transpose()?;

    let server_side = args.server_side.unwrap_or(false);

    // A client-only context cannot serve, a server-only context cannot connect
    let side_conflict = match zelf.protocol {
        PROTOCOL_TLS_CLIENT if server_side => {
            Some("Cannot create a server socket with a PROTOCOL_TLS_CLIENT context")
        }
        PROTOCOL_TLS_SERVER if !server_side => {
            Some("Cannot create a client socket with a PROTOCOL_TLS_SERVER context")
        }
        _ => None,
    };
    if let Some(message) = side_conflict {
        return Err(vm
            .new_os_subtype_error(PySSLError::class(&vm.ctx).to_owned(), None, message)
            .upcast());
    }

    let ssl_socket = PySSLSocket {
        sock: vm.ctx.none(), // No socket in BIO mode
        context: PyRwLock::new(zelf),
        server_side,
        server_hostname: PyRwLock::new(hostname),
        connection: PyMutex::new(None),
        handshake_done: PyMutex::new(false),
        session_was_reused: PyMutex::new(false),
        owner: PyRwLock::new(args.owner.into_option()),
        // A Python None session counts as "no session"
        session: PyRwLock::new(args.session.into_option().filter(|s| !vm.is_none(s))),
        verified_chain: PyRwLock::new(None),
        incoming_bio: Some(args.incoming),
        outgoing_bio: Some(args.outgoing),
        sni_state: PyRwLock::new(None),
        pending_context: PyRwLock::new(None),
        client_hello_buffer: PyMutex::new(None),
        shutdown_state: PyMutex::new(ShutdownState::NotStarted),
        pending_tls_output: PyMutex::new(Vec::new()),
        write_buffered_len: PyMutex::new(0),
        deferred_cert_error: Arc::new(ParkingRwLock::new(None)),
    };

    // Instantiate with the Python-visible `_ssl._SSLSocket` type
    ssl_socket
        .into_ref_with_type(vm, vm.class("_ssl", "_SSLSocket"))
        .map_err(|_| vm.new_type_error("Failed to create SSLSocket"))
}
// Helper functions (private):
/// Convert a path argument (str or bytes-like) into an owned `String`.
///
/// A str is copied as-is; bytes must be valid UTF-8, otherwise a ValueError
/// is returned.
fn parse_path_arg(
    arg: &Either<PyStrRef, ArgBytesLike>,
    vm: &VirtualMachine,
) -> PyResult<String> {
    let raw = match arg {
        Either::A(text) => return Ok(text.as_str().to_owned()),
        Either::B(buffer) => buffer.borrow_buf().to_vec(),
    };
    String::from_utf8(raw).map_err(|_| vm.new_value_error("path contains invalid UTF-8".to_owned()))
}
/// Parse password argument (str, bytes-like, callable, or None)
///
/// Returns (immediate_password, callable) where:
/// - immediate_password: Some(string) if password is str/bytes, None otherwise
/// - callable: Some(PyObjectRef) if password is callable, None otherwise
///
/// A missing argument and an explicit Python `None` both mean "no password"
/// and yield `(None, None)`. Bytes-like passwords must be valid UTF-8; any
/// other kind of object is rejected with a TypeError.
fn parse_password_argument(
    password: &OptionalArg<PyObjectRef>,
    vm: &VirtualMachine,
) -> PyResult<(Option<String>, Option<PyObjectRef>)> {
    let p = match password {
        // `password=None` is the documented default and must not be rejected
        OptionalArg::Present(p) if !vm.is_none(p) => p,
        _ => return Ok((None, None)),
    };

    // Try string first
    if let Ok(pwd_str) = PyStrRef::try_from_object(vm, p.clone()) {
        return Ok((Some(pwd_str.as_str().to_owned()), None));
    }
    // Try bytes-like
    if let Ok(pwd_bytes_like) = ArgBytesLike::try_from_object(vm, p.clone()) {
        let raw = pwd_bytes_like.borrow_buf().to_vec();
        let pwd = String::from_utf8(raw)
            .map_err(|_| vm.new_type_error("password bytes must be valid UTF-8".to_owned()))?;
        return Ok((Some(pwd), None));
    }
    // Try callable: it is returned unevaluated so the caller can invoke it lazily
    if p.is_callable() {
        return Ok((None, Some(p.clone())));
    }
    Err(vm.new_type_error("password should be a string, bytes, or callable".to_owned()))
}
/// Helper: load the certificates in the file at `path` into `root_store`,
/// recording them in `ca_certs_der` as well.
///
/// NotFound / PermissionDenied are returned as the corresponding OS error
/// (errno preserved); every other failure is reported as an SSL error with
/// library "X509" and message "PEM lib".
fn load_certs_from_file_helper(
    &self,
    root_store: &mut RootCertStore,
    ca_certs_der: &mut Vec<Vec<u8>>,
    path: &str,
    vm: &VirtualMachine,
) -> PyResult<cert::CertStats> {
    let mut loader = cert::CertLoader::new(root_store, ca_certs_der);
    match loader.load_from_file(path) {
        Ok(stats) => Ok(stats),
        Err(e) => {
            let is_access_error = matches!(
                e.kind(),
                std::io::ErrorKind::NotFound | std::io::ErrorKind::PermissionDenied
            );
            Err(if is_access_error {
                e.into_pyexception(vm)
            } else {
                // PEM parsing errors
                super::compat::SslError::create_ssl_error_with_reason(
                    vm,
                    Some("X509"),
                    "",
                    "PEM lib",
                )
            })
        }
    }
}
/// Helper: Load certificates from directory into existing store
///
/// The certificates are added to `root_store`. Their DER encodings are also
/// merged into `self.capath_certs_der` for lazy loading simulation (CPython
/// only returns these in get_ca_certs() after they're used in handshake).
///
/// The merge is additive: certificates recorded by earlier calls (for other
/// directories) are kept, and encodings that are already recorded are not
/// added a second time. This keeps the tracking list in step with
/// `root_store`, which also accumulates across calls.
///
/// Loader I/O errors are converted to the matching Python exception.
fn load_certs_from_dir_helper(
    &self,
    root_store: &mut RootCertStore,
    path: &str,
    vm: &VirtualMachine,
) -> PyResult<cert::CertStats> {
    let mut capath_certs = Vec::new();
    let mut loader = cert::CertLoader::new(root_store, &mut capath_certs);
    let stats = loader
        .load_from_dir(path)
        .map_err(|e| e.into_pyexception(vm))?;

    // Merge instead of overwrite, so that a second capath does not erase the first
    let mut tracked = self.capath_certs_der.write();
    for der in capath_certs {
        if !tracked.contains(&der) {
            tracked.push(der);
        }
    }
    Ok(stats)
}
/// Helper: load certificates from an in-memory buffer into `root_store`,
/// recording them in `ca_certs_der` as well.
///
/// `pem_only` is true for str input. Every failure is reported as SSLError;
/// the "no start line" and "not enough data" cases get a fixed message, any
/// other failure carries the loader's own message.
fn load_certs_from_bytes_helper(
    &self,
    root_store: &mut RootCertStore,
    ca_certs_der: &mut Vec<Vec<u8>>,
    data: &[u8],
    pem_only: bool,
    vm: &VirtualMachine,
) -> PyResult<cert::CertStats> {
    let mut loader = cert::CertLoader::new(root_store, ca_certs_der);
    // treat_all_as_ca=true: CPython counts all certificates loaded via cadata as CA certs
    // regardless of their Basic Constraints extension
    let outcome = loader.load_from_bytes_ex(data, true, pem_only);
    outcome.map_err(|e| {
        // Pick the message first, then build the SSLError once
        let err_msg = e.to_string();
        let message = if err_msg.contains("no start line") {
            "no start line: cadata does not contain a certificate".to_owned()
        } else if err_msg.contains("not enough data") {
            "not enough data: cadata does not contain a certificate".to_owned()
        } else {
            err_msg
        };
        vm.new_os_subtype_error(PySSLError::class(&vm.ctx).to_owned(), None, message)
            .upcast()
    })
}
/// Helper: Try to parse data as CRL (PEM or DER format)
fn try_parse_crl(
&self,
data: &[u8],
) -> Result<CertificateRevocationListDer<'static>, String> {
// Try PEM format first
let mut cursor = std::io::Cursor::new(data);
let mut crl_iter = rustls_pemfile::crls(&mut cursor);
if let Some(Ok(crl)) = crl_iter.next() {
return Ok(crl);
}
// Try DER format
// Basic validation: CRL should start with SEQUENCE tag (0x30)
if !data.is_empty() && data[0] == 0x30 {
return Ok(CertificateRevocationListDer::from(data.to_vec()));
}
Err("Not a valid CRL file".to_string())
}
/// Helper: read the file at `path` and return its CRL, if it is one.
///
/// Returns `Ok(None)` when the file was read but is not a CRL (it may be a
/// certificate file instead). Read failures become Python exceptions:
/// NotFound / PermissionDenied keep their errno, anything else is a plain
/// OSError carrying the error text.
fn load_crl_from_file(
    &self,
    path: &str,
    vm: &VirtualMachine,
) -> PyResult<Option<CertificateRevocationListDer<'static>>> {
    let data = match std::fs::read(path) {
        Ok(bytes) => bytes,
        Err(e) => {
            return Err(match e.kind() {
                std::io::ErrorKind::NotFound | std::io::ErrorKind::PermissionDenied => {
                    e.into_pyexception(vm)
                }
                _ => vm.new_os_error(e.to_string()),
            });
        }
    };
    // A parse failure is not an error here, it just means "not a CRL"
    Ok(self.try_parse_crl(&data).ok())
}
/// Helper: copy a cadata argument (str or bytes-like) into an owned byte
/// vector. A str contributes its UTF-8 bytes. This conversion cannot fail;
/// `_vm` is unused.
fn parse_cadata_arg(
    &self,
    arg: &Either<PyStrRef, ArgBytesLike>,
    _vm: &VirtualMachine,
) -> PyResult<Vec<u8>> {
    let bytes = match arg {
        Either::A(text) => text.as_str().as_bytes().to_vec(),
        Either::B(buffer) => buffer.borrow_buf().to_vec(),
    };
    Ok(bytes)
}
/// Helper: add the totals from one load operation to the context's running
/// certificate counters (`x509_cert_count` and `ca_cert_count`).
fn update_cert_stats(&self, stats: cert::CertStats) {
    let added_total = stats.total_certs;
    let added_ca = stats.ca_certs;
    {
        let mut total = self.x509_cert_count.write();
        *total += added_total;
    }
    let mut ca_total = self.ca_cert_count.write();
    *ca_total += added_ca;
}
}
impl Constructor for PySSLContext {
type Args = (i32,);
fn py_new(
_cls: &Py<PyType>,
(protocol,): Self::Args,
vm: &VirtualMachine,
) -> PyResult<Self> {
// Validate protocol
match protocol {
PROTOCOL_TLS | PROTOCOL_TLS_CLIENT | PROTOCOL_TLS_SERVER | PROTOCOL_TLSv1_2
| PROTOCOL_TLSv1_3 => {
// Valid protocols
}
PROTOCOL_TLSv1 | PROTOCOL_TLSv1_1 => {
return Err(vm.new_value_error(
"TLS 1.0 and 1.1 are not supported by rustls for security reasons",
));
}
_ => {
return Err(vm.new_value_error(format!("invalid protocol version: {protocol}")));
}
}
// Set default options
// OP_ALL | OP_NO_SSLv2 | OP_NO_SSLv3 | OP_NO_COMPRESSION |
// OP_CIPHER_SERVER_PREFERENCE | OP_SINGLE_DH_USE | OP_SINGLE_ECDH_USE |
// OP_ENABLE_MIDDLEBOX_COMPAT
let default_options = OP_ALL
| OP_NO_SSLv2
| OP_NO_SSLv3
| OP_NO_COMPRESSION
| OP_CIPHER_SERVER_PREFERENCE
| OP_SINGLE_DH_USE
| OP_SINGLE_ECDH_USE
| OP_ENABLE_MIDDLEBOX_COMPAT;
// Set default verify_mode based on protocol
// PROTOCOL_TLS_CLIENT defaults to CERT_REQUIRED
// PROTOCOL_TLS_SERVER defaults to CERT_NONE
let default_verify_mode = if protocol == PROTOCOL_TLS_CLIENT {
CERT_REQUIRED
} else {
CERT_NONE
};
// Set default verify_flags based on protocol
// Both PROTOCOL_TLS_CLIENT and PROTOCOL_TLS_SERVER only set VERIFY_X509_TRUSTED_FIRST
// Note: VERIFY_X509_PARTIAL_CHAIN and VERIFY_X509_STRICT are NOT set here
// - they're only added by create_default_context() in Python's ssl.py
let default_verify_flags = VERIFY_DEFAULT | VERIFY_X509_TRUSTED_FIRST;
// Set minimum and maximum protocol versions based on protocol constant
// specific protocol versions fix both min and max
let (min_version, max_version) = match protocol {
PROTOCOL_TLSv1_2 => (PROTO_TLSv1_2, PROTO_TLSv1_2), // Only TLS 1.2
PROTOCOL_TLSv1_3 => (PROTO_TLSv1_3, PROTO_TLSv1_3), // Only TLS 1.3
_ => (PROTO_MINIMUM_SUPPORTED, PROTO_MAXIMUM_SUPPORTED), // Auto-negotiate
};
// IMPORTANT: Create shared session cache BEFORE PySSLContext
// Both client_session_cache and PythonClientSessionStore.session_cache
// MUST point to the same HashMap to ensure Python-level and Rustls-level
// sessions are synchronized
let shared_session_cache = Arc::new(ParkingRwLock::new(HashMap::new()));
let rustls_client_store = Arc::new(PythonClientSessionStore {
inner: Arc::new(rustls::client::ClientSessionMemoryCache::new(
SSL_SESSION_CACHE_SIZE,
)),
session_cache: shared_session_cache.clone(),
});
Ok(PySSLContext {
protocol,
check_hostname: PyRwLock::new(protocol == PROTOCOL_TLS_CLIENT),
verify_mode: PyRwLock::new(default_verify_mode),
verify_flags: PyRwLock::new(default_verify_flags),
client_config: PyRwLock::new(None),
server_config: PyRwLock::new(None),
root_certs: PyRwLock::new(RootCertStore::empty()),
ca_certs_der: PyRwLock::new(Vec::new()),
capath_certs_der: PyRwLock::new(Vec::new()),
crls: PyRwLock::new(Vec::new()),
cert_keys: PyRwLock::new(Vec::new()),
options: PyRwLock::new(default_options),
alpn_protocols: PyRwLock::new(Vec::new()),
require_alpn_match: PyRwLock::new(false),
post_handshake_auth: PyRwLock::new(false),
num_tickets: PyRwLock::new(2), // TLS 1.3 default
minimum_version: PyRwLock::new(min_version),
maximum_version: PyRwLock::new(max_version),
sni_callback: PyRwLock::new(None),
msg_callback: PyRwLock::new(None),
ecdh_curve: PyRwLock::new(None),
ca_cert_count: PyRwLock::new(0),
x509_cert_count: PyRwLock::new(0),
// Use the shared cache created above
client_session_cache: shared_session_cache,
rustls_session_store: rustls_client_store,
rustls_server_session_store: rustls::server::ServerSessionMemoryCache::new(
SSL_SESSION_CACHE_SIZE,
),
server_ticketer: rustls::crypto::aws_lc_rs::Ticketer::new()
.expect("Failed to create shared ticketer for TLS 1.2 session resumption"),
accept_count: AtomicUsize::new(0),
session_hits: AtomicUsize::new(0),
selected_ciphers: PyRwLock::new(None),
})
}
}
// SSLSocket - represents a TLS-wrapped socket.
// Exposed to Python as `_SSLSocket`. Mutable state lives behind locks because
// the methods on this type only take `&self`.
#[pyattr]
#[pyclass(name = "_SSLSocket", module = "ssl", traverse)]
#[derive(Debug, PyPayload)]
pub(crate) struct PySSLSocket {
// Underlying socket object (methods such as `gettimeout` are called on it)
sock: PyObjectRef,
// SSL context this socket currently uses
context: PyRwLock<PyRef<PySSLContext>>,
// true for the server side of the connection, false for the client side
#[pytraverse(skip)]
server_side: bool,
// Server hostname for SNI
#[pytraverse(skip)]
server_hostname: PyRwLock<Option<String>>,
// TLS connection state (None until a connection has been created)
#[pytraverse(skip)]
connection: PyMutex<Option<TlsConnection>>,
// Handshake completed flag
#[pytraverse(skip)]
handshake_done: PyMutex<bool>,
// Session was reused (for session resumption tracking)
#[pytraverse(skip)]
session_was_reused: PyMutex<bool>,
// Owner (SSLSocket instance that owns this _SSLSocket)
owner: PyRwLock<Option<PyObjectRef>>,
// Session object used for resumption
session: PyRwLock<Option<PyObjectRef>>,
// Verified certificate chain (built during verification)
#[allow(dead_code)]
#[pytraverse(skip)]
verified_chain: PyRwLock<Option<Vec<CertificateDer<'static>>>>,
// MemoryBIO mode (optional): BIO mode is active only when BOTH are set
incoming_bio: Option<PyRef<PyMemoryBIO>>,
outgoing_bio: Option<PyRef<PyMemoryBIO>>,
// SNI certificate resolver state (for server-side only)
#[pytraverse(skip)]
sni_state: PyRwLock<Option<Arc<ParkingMutex<SniCertName>>>>,
// Pending context change (for SNI callback deferred handling)
pending_context: PyRwLock<Option<PyRef<PySSLContext>>>,
// Buffer to store ClientHello for connection recreation;
// `None` means no ClientHello has been saved yet
#[pytraverse(skip)]
client_hello_buffer: PyMutex<Option<Vec<u8>>>,
// Shutdown state for tracking close-notify exchange
#[pytraverse(skip)]
shutdown_state: PyMutex<ShutdownState>,
// Pending TLS output buffer for non-blocking sockets
// Stores unsent TLS bytes when sock_send() would block
// This prevents data loss when write_tls() drains rustls' internal buffer
// but the socket cannot accept all the data immediately
#[pytraverse(skip)]
pub(crate) pending_tls_output: PyMutex<Vec<u8>>,
// Tracks bytes already buffered in rustls for the current write operation
// Prevents duplicate writes when retrying after WantWrite/WantRead
#[pytraverse(skip)]
pub(crate) write_buffered_len: PyMutex<usize>,
// Deferred client certificate verification error (for TLS 1.3)
// Stores error message if client cert verification failed during handshake
// Error is raised on first I/O operation after handshake
// Using Arc to share with the certificate verifier
#[pytraverse(skip)]
deferred_cert_error: Arc<ParkingRwLock<Option<String>>>,
}
// Shutdown state for tracking the close-notify exchange performed by unwrap().
// Stored in `PySSLSocket::shutdown_state`.
#[derive(Debug, Clone, Copy, PartialEq)]
enum ShutdownState {
NotStarted, // unwrap() not called yet
SentCloseNotify, // close-notify sent, waiting for peer's response
Completed, // unwrap() completed successfully
}
#[pyclass(with(Constructor), flags(BASETYPE))]
impl PySSLSocket {
/// Returns `true` when this socket runs in MemoryBIO mode, i.e. when both the
/// incoming and the outgoing BIO are set.
pub(crate) fn is_bio_mode(&self) -> bool {
    matches!(
        (&self.incoming_bio, &self.outgoing_bio),
        (Some(_), Some(_))
    )
}
/// Returns a new reference to the incoming BIO as a generic object
/// (for EOF checking), or `None` when no incoming BIO is set.
pub(crate) fn incoming_bio(&self) -> Option<PyObjectRef> {
    let bio = self.incoming_bio.as_ref()?;
    Some(bio.clone().into())
}
/// Raise a deferred certificate verification error (TLS 1.3), if one is stored.
///
/// The stored message is cleared before raising so it is reported only once.
fn check_deferred_cert_error(&self, vm: &VirtualMachine) -> PyResult<()> {
    // Copy the message out first so the read guard is released before writing
    let stored = self.deferred_cert_error.read().clone();
    let Some(message) = stored else {
        return Ok(());
    };
    *self.deferred_cert_error.write() = None;
    // Raise OSError with the stored error message
    Err(vm.new_os_error(message))
}
/// Get the underlying socket's timeout as a `Duration`.
///
/// Returns:
/// - `Ok(None)` in BIO mode, or when `gettimeout()` returns `None` (blocking forever)
/// - `Ok(Some(0s))` when the timeout is zero or negative (non-blocking)
/// - `Ok(Some(d))` for a positive, finite timeout
///
/// Raises `ValueError` when the timeout is NaN, infinite, or too large to be
/// represented as a `Duration`, instead of panicking on the conversion.
pub(crate) fn get_socket_timeout(&self, vm: &VirtualMachine) -> PyResult<Option<Duration>> {
    if self.is_bio_mode() {
        return Ok(None);
    }
    // Get timeout from socket
    let timeout_obj = self.sock.get_attr("gettimeout", vm)?.call((), vm)?;
    // timeout can be None (blocking), 0.0 (non-blocking), or positive float
    if vm.is_none(&timeout_obj) {
        // None means blocking forever
        return Ok(None);
    }
    let timeout_float: f64 = timeout_obj.try_into_value(vm)?;
    if timeout_float <= 0.0 {
        // 0 means non-blocking
        return Ok(Some(Duration::ZERO));
    }
    // `gettimeout` is resolved dynamically, so the value is not guaranteed to be
    // a sane float; use the checked conversion (NaN fails the `<=` test above
    // and would otherwise reach a panicking conversion)
    Duration::try_from_secs_f64(timeout_float)
        .map(Some)
        .map_err(|e| vm.new_value_error(format!("invalid socket timeout {timeout_float}: {e}")))
}
/// Create and store a session object after a successful handshake.
///
/// Does nothing on the server side, or when a non-`None` session object is
/// already stored. Session metadata comes from the context's client session
/// cache (keyed by server hostname) when available, otherwise it is synthesized.
fn create_session_after_handshake(&self, vm: &VirtualMachine) -> PyResult<()> {
    // Only create session for client-side connections
    if self.server_side {
        return Ok(());
    }
    // Keep an existing session unless it is the Python `None` object
    let existing = self.session.read().clone();
    if existing.is_some_and(|s| !vm.is_none(&s)) {
        return Ok(());
    }

    let server_name = self.server_hostname.read().clone();
    // IMPORTANT: Acquire and release locks quickly to avoid deadlock.
    // Only the Arc is cloned; the context guard is gone after this statement.
    let cache = self.context.read().client_session_cache.clone();

    let (session_id, creation_time, lifetime) = match server_name {
        Some(ref name) => {
            let key = name.as_bytes().to_vec();
            // Clone the cache entry, releasing the cache lock immediately
            let cached = cache.read().get(&key).cloned();
            match cached {
                Some(entry) => {
                    let data = entry.lock();
                    (data.session_id.clone(), data.creation_time, data.lifetime)
                }
                None => {
                    // Create new session ID if not in cache
                    let now = std::time::SystemTime::now();
                    (generate_session_id_from_metadata(name, &now), now, 7200)
                }
            }
        }
        None => {
            // No server name, use defaults
            let now = std::time::SystemTime::now();
            (vec![0; 16], now, 7200)
        }
    };

    // Create a new SSLSession object with real metadata
    let py_session = PySSLSession {
        // Use dummy session data to indicate we have a ticket
        // TLS 1.2+ always uses session tickets/resumption
        session_data: vec![1], // Non-empty to indicate has_ticket=True
        session_id,
        creation_time,
        lifetime,
    }
    .into_pyobject(vm);
    *self.session.write() = Some(py_session);
    Ok(())
}
/// Track which CA certificate from capath was used to verify peer
///
/// This simulates lazy loading behavior: capath certificates
/// are only added to get_ca_certs() after they're actually used in a handshake.
///
/// Returns `Err` with a description when there is no connection, no peer
/// certificates, or the top peer certificate cannot be parsed.
fn track_used_ca_from_capath(&self) -> Result<(), String> {
    // Snapshot the capath certificates, releasing the context lock quickly
    let capath_certs = {
        let context = self.context.read();
        let certs = context.capath_certs_der.read();
        if certs.is_empty() {
            return Ok(());
        }
        certs.clone()
    };

    // Copy the last certificate of the peer chain, releasing the connection lock quickly
    let top_cert_der = {
        let conn_guard = self.connection.lock();
        let conn = conn_guard.as_ref().ok_or("No connection")?;
        let peer_certs = conn.peer_certificates().ok_or("No peer certificates")?;
        match peer_certs.iter().next_back() {
            Some(cert) => cert.as_ref().to_vec(),
            None => return Ok(()),
        }
    };

    // The last cert in the chain is the one closest to the root.
    // Note: Server usually doesn't send the root CA, so we check the last cert's issuer
    let (_, top_cert) = x509_parser::parse_x509_certificate(&top_cert_der)
        .map_err(|e| format!("Failed to parse top cert: {e}"))?;
    let top_issuer = top_cert.issuer();

    // Look for a self-signed capath certificate whose subject is that issuer
    // (unparseable capath entries are skipped)
    let matching_ca = capath_certs.iter().find_map(|candidate| {
        let (_, ca) = x509_parser::parse_x509_certificate(candidate).ok()?;
        if ca.subject() == ca.issuer() && ca.subject() == top_issuer {
            Some(candidate.clone())
        } else {
            None
        }
    });

    let Some(ca_der) = matching_ca else {
        return Ok(());
    };
    // Record the CA once; duplicates are not added
    let context = self.context.read();
    let mut known = context.ca_certs_der.write();
    if !known.iter().any(|c| c == &ca_der) {
        known.push(ca_der);
    }
    Ok(())
}
/// Complete the handshake: mark it done, record session reuse, update the
/// context's server statistics, track the capath CA (client side) and create
/// the session object.
fn complete_handshake(&self, vm: &VirtualMachine) -> PyResult<()> {
    *self.handshake_done.lock() = true;

    // The connection lock is only held for this one statement
    let was_resumed = self
        .connection
        .lock()
        .as_ref()
        .is_some_and(|conn| conn.is_session_resumed());
    *self.session_was_reused.lock() = was_resumed;

    if self.server_side {
        // Server side: count every accepted handshake, and resumed ones as hits
        let context = self.context.read();
        context.accept_count.fetch_add(1, Ordering::SeqCst);
        if was_resumed {
            context.session_hits.fetch_add(1, Ordering::SeqCst);
        }
    } else {
        // Client side: simulate lazy loading of capath certificates.
        // Don't fail handshake if tracking fails
        let _ = self.track_used_ca_from_capath();
    }

    self.create_session_after_handshake(vm)
}
// Internal implementation with timeout control
pub(crate) fn sock_wait_for_io_impl(
&self,
kind: SelectKind,
vm: &VirtualMachine,
) -> PyResult<bool> {
if self.is_bio_mode() {
// BIO mode doesn't use select
return Ok(false);
}
// Get timeout
let timeout = self.get_socket_timeout(vm)?;
// Check for non-blocking mode (timeout = 0)
if let Some(t) = timeout
&& t.is_zero()
{
// Non-blocking mode - don't use select
return Ok(false);
}
// Use select with the effective timeout
let py_socket: PyRef<PySocket> = self.sock.clone().try_into_value(vm)?;
let socket = py_socket
.sock()
.map_err(|e| vm.new_os_error(format!("Failed to get socket: {e}")))?;
let timed_out = sock_select(&socket, kind, timeout)
.map_err(|e| vm.new_os_error(format!("select failed: {e}")))?;
Ok(timed_out)
}
/// Wait until the socket is ready for `kind`, using an explicit `timeout`
/// instead of the socket's own.
///
/// Returns `Ok(true)` when the wait timed out and `Ok(false)` otherwise.
/// No wait is performed (result `false`) in BIO mode or when `timeout` is zero.
pub(crate) fn sock_wait_for_io_with_timeout(
    &self,
    kind: SelectKind,
    timeout: Option<core::time::Duration>,
    vm: &VirtualMachine,
) -> PyResult<bool> {
    // BIO mode doesn't use select; a zero timeout means non-blocking
    let skip_select = self.is_bio_mode() || timeout.is_some_and(|t| t.is_zero());
    if skip_select {
        return Ok(false);
    }
    let py_socket: PyRef<PySocket> = self.sock.clone().try_into_value(vm)?;
    let socket = match py_socket.sock() {
        Ok(s) => s,
        Err(e) => return Err(vm.new_os_error(format!("Failed to get socket: {e}"))),
    };
    match sock_select(&socket, kind, timeout) {
        Ok(timed_out) => Ok(timed_out),
        Err(e) => Err(vm.new_os_error(format!("select failed: {e}"))),
    }
}
// SNI (Server Name Indication) Helper Methods:
// These methods support the server-side handshake SNI callback mechanism
/// Check if this is the first read during handshake (for SNI callback).
/// Returns true while no ClientHello has been saved yet, regardless of SNI presence.
pub(crate) fn is_first_sni_read(&self) -> bool {
    let saved_hello = self.client_hello_buffer.lock();
    saved_hello.is_none()
}
/// Check if an SNI callback is configured on the current context
pub(crate) fn has_sni_callback(&self) -> bool {
    // Two different locks are taken here (socket -> context field)
    let context = self.context.read();
    let configured = context.sni_callback.read().is_some();
    configured
}
/// Save a copy of the ClientHello bytes for potential connection recreation,
/// replacing any previously saved copy
pub(crate) fn save_client_hello_from_bytes(&self, bytes_data: &[u8]) {
    let mut slot = self.client_hello_buffer.lock();
    *slot = Some(Vec::from(bytes_data));
}
/// Get the SNI name stored in the resolver state, if any
pub(crate) fn get_extracted_sni_name(&self) -> Option<String> {
    // Clone the Arc first so the sni_state read guard is released before
    // the inner mutex is locked (avoids nesting sni_state.read -> arc.lock)
    let state = self.sni_state.read().clone()?;
    let name = state.lock().1.clone();
    name
}
/// Invoke the Python SNI callback as `callback(ssl_sock, server_name, context)`.
///
/// Raises `ValueError` if no callback is set. If the callback raises, or
/// returns something other than `None` or an integer, the problem is reported
/// as unraisable and an `SSLError` with `reason` set to
/// `TLSV1_ALERT_INTERNAL_ERROR` is returned.
pub(crate) fn invoke_sni_callback(
    &self,
    sni_name: Option<&str>,
    vm: &VirtualMachine,
) -> PyResult<()> {
    // Builds the SSLError returned for both failure paths below
    // (like SSL_TLSEXT_ERR_ALERT_FATAL)
    let internal_error = |message: &'static str| -> PyBaseExceptionRef {
        let exc: PyBaseExceptionRef = vm
            .new_os_subtype_error(PySSLError::class(&vm.ctx).to_owned(), None, message)
            .upcast();
        let _ = exc.as_object().set_attr(
            "reason",
            vm.ctx.new_str("TLSV1_ALERT_INTERNAL_ERROR"),
            vm,
        );
        exc
    };

    let callback = self
        .context
        .read()
        .sni_callback
        .read()
        .clone()
        .ok_or_else(|| vm.new_value_error("SNI callback not set"))?;

    // Callback arguments: owning SSLSocket (or None), server name (or None), context
    let ssl_sock = self.owner.read().clone().unwrap_or(vm.ctx.none());
    let server_name_py: PyObjectRef = if let Some(name) = sni_name {
        vm.ctx.new_str(name.to_string()).into()
    } else {
        vm.ctx.none()
    };
    let initial_context: PyObjectRef = self.context.read().clone().into();

    // Exceptions from the callback are caught and reported as unraisable
    let result = match callback.call((ssl_sock, server_name_py, initial_context), vm) {
        Ok(value) => value,
        Err(exc) => {
            vm.run_unraisable(
                exc,
                Some("in ssl servername callback".to_owned()),
                callback.clone(),
            );
            return Err(internal_error("SNI callback raised exception"));
        }
    };

    // Return value must be None or convertible to an integer
    if vm.is_none(&result) || result.try_to_value::<i32>(vm).is_ok() {
        return Ok(());
    }

    // Type conversion failed - raise TypeError as unraisable
    let type_error = vm.new_type_error(format!(
        "servername callback must return None or an integer, not '{}'",
        result.class().name()
    ));
    vm.run_unraisable(type_error, None, result.clone());

    // RUSTLS API LIMITATION:
    // We cannot send a TLS InternalError alert to the client here because:
    // 1. Rustls does not provide a public API like send_fatal_alert()
    // 2. This method is called AFTER dropping the connection lock (to prevent deadlock)
    // 3. By the time we detect the error, the connection is no longer available
    //
    // CPython/OpenSSL behavior:
    // - SNI callback runs inside SSL_do_handshake with connection active
    // - Sets *al = SSL_AD_INTERNAL_ERROR
    // - OpenSSL automatically sends alert before returning
    //
    // RustPython/Rustls behavior:
    // - SNI callback runs after dropping connection lock (deadlock prevention)
    // - Exception has reason='TLSV1_ALERT_INTERNAL_ERROR' for error reporting
    // - TCP connection closes without sending TLS alert to client
    //
    // If rustls adds send_fatal_alert() API in the future, we should:
    // - Re-acquire connection lock after callback
    // - Call: connection.send_fatal_alert(AlertDescription::InternalError)
    // - Then close connection
    Err(internal_error("SNI callback returned invalid type"))
}
/// Receive up to `size` bytes, bypassing any SSL wrapper.
///
/// In BIO mode this calls `read(size)` on the incoming BIO; otherwise it calls
/// `socket.socket.recv(self.sock, size)` (no flags are passed).
pub(crate) fn sock_recv(&self, size: usize, vm: &VirtualMachine) -> PyResult<PyObjectRef> {
    match self.incoming_bio {
        // BIO mode: read from the incoming BIO
        Some(ref bio) => {
            let bio_obj: PyObjectRef = bio.clone().into();
            let read_method = bio_obj.get_attr("read", vm)?;
            read_method.call((vm.ctx.new_int(size),), vm)
        }
        // Normal socket mode: use the unbound method from the base socket class
        None => {
            let recv_method = vm
                .import("socket", 0)?
                .get_attr("socket", vm)?
                .get_attr("recv", vm)?;
            recv_method.call((self.sock.clone(), vm.ctx.new_int(size)), vm)
        }
    }
}
/// Peek at up to `size` bytes of socket data without consuming it (MSG_PEEK).
/// Used during TLS shutdown to avoid consuming post-TLS cleartext data.
pub(crate) fn sock_peek(&self, size: usize, vm: &VirtualMachine) -> PyResult<PyObjectRef> {
    let socket_mod = vm.import("socket", 0)?;
    // Unbound `socket.socket.recv`, called with the raw socket as `self`
    let recv_method = socket_mod.get_attr("socket", vm)?.get_attr("recv", vm)?;
    let msg_peek = socket_mod.get_attr("MSG_PEEK", vm)?;
    let args = (self.sock.clone(), vm.ctx.new_int(size), msg_peek);
    recv_method.call(args, vm)
}
/// Socket send - just sends data, caller must handle pending flush.
/// Use flush_pending_tls_output before this if ordering is important.
///
/// In BIO mode this calls `write(data)` on the outgoing BIO; otherwise it calls
/// `socket.socket.send(self.sock, data)`. Returns whatever that call returns.
pub(crate) fn sock_send(&self, data: &[u8], vm: &VirtualMachine) -> PyResult<PyObjectRef> {
    match self.outgoing_bio {
        // BIO mode: write to the outgoing BIO
        Some(ref bio) => {
            let bio_obj: PyObjectRef = bio.clone().into();
            let write_method = bio_obj.get_attr("write", vm)?;
            write_method.call((vm.ctx.new_bytes(data.to_vec()),), vm)
        }
        // Normal socket mode: use the unbound method from the base socket class
        None => {
            let send_method = vm
                .import("socket", 0)?
                .get_attr("socket", vm)?
                .get_attr("send", vm)?;
            send_method.call((self.sock.clone(), vm.ctx.new_bytes(data.to_vec())), vm)
        }
    }
}
/// Flush any pending TLS output data to the socket.
///
/// `deadline`, when given, bounds each wait by the time remaining until that
/// instant; otherwise the socket's own timeout is used.
///
/// Bytes are removed from the pending buffer as soon as the socket accepts
/// them, so on EVERY exit path (including early returns through `?`) the
/// buffer holds exactly the bytes that have not been sent yet. A later retry
/// therefore never re-sends data that already went out.
///
/// Errors:
/// - `SSLWantWriteError` when the socket is non-blocking and cannot accept data
/// - a timeout error when the deadline or socket timeout expires
/// - any other error raised while waiting or sending
pub(crate) fn flush_pending_tls_output(
    &self,
    vm: &VirtualMachine,
    deadline: Option<std::time::Instant>,
) -> PyResult<()> {
    let mut pending = self.pending_tls_output.lock();
    if pending.is_empty() {
        return Ok(());
    }
    let socket_timeout = self.get_socket_timeout(vm)?;
    let is_non_blocking = socket_timeout.map(|t| t.is_zero()).unwrap_or(false);

    while !pending.is_empty() {
        // Calculate timeout: use deadline if provided, otherwise use socket timeout
        let timeout_to_use = if let Some(dl) = deadline {
            let now = std::time::Instant::now();
            if now >= dl {
                // Deadline already passed; unsent data stays in the pending buffer
                return Err(
                    timeout_error_msg(vm, "The operation timed out".to_string()).upcast()
                );
            }
            Some(dl - now)
        } else {
            socket_timeout
        };

        // Use sock_select directly with calculated timeout
        let py_socket: PyRef<PySocket> = self.sock.clone().try_into_value(vm)?;
        let socket = py_socket
            .sock()
            .map_err(|e| vm.new_os_error(format!("Failed to get socket: {e}")))?;
        let timed_out = sock_select(&socket, SelectKind::Write, timeout_to_use)
            .map_err(|e| vm.new_os_error(format!("select failed: {e}")))?;
        if timed_out {
            if is_non_blocking {
                return Err(create_ssl_want_write_error(vm).upcast());
            }
            return Err(
                timeout_error_msg(vm, "The write operation timed out".to_string()).upcast(),
            );
        }

        match self.sock_send(&pending[..], vm) {
            Ok(result) => {
                let sent: usize = result.try_to_value::<isize>(vm)?.try_into().unwrap_or(0);
                if sent == 0 {
                    if is_non_blocking {
                        return Err(create_ssl_want_write_error(vm).upcast());
                    }
                    // Socket said ready but sent 0 bytes - retry
                    continue;
                }
                // Drop the accepted prefix right away; clamp so an implausible
                // count can never index past the buffer
                let accepted = sent.min(pending.len());
                pending.drain(..accepted);
            }
            Err(e) => {
                if is_blocking_io_error(&e, vm) {
                    if is_non_blocking {
                        return Err(create_ssl_want_write_error(vm).upcast());
                    }
                    continue;
                }
                // Unsent data stays in the pending buffer for other errors too
                return Err(e);
            }
        }
    }
    // All data sent successfully; the buffer is already empty
    Ok(())
}
/// Send TLS output data to socket, saving unsent bytes to pending buffer
/// This prevents data loss when rustls' write_tls() drains its internal buffer
/// but the socket cannot accept all the data immediately
///
/// Every error exit appends the not-yet-sent tail of `buf` to
/// `pending_tls_output` before returning. This includes failures while
/// looking up the socket timeout, waiting for writability, or interpreting
/// the value returned by `send`.
///
/// Errors:
/// - `SSLWantWriteError` when the socket is non-blocking and cannot accept data
/// - a timeout error when the wait for writability times out
/// - any other error raised while waiting or sending
fn send_tls_output(&self, buf: Vec<u8>, vm: &VirtualMachine) -> PyResult<()> {
    if buf.is_empty() {
        return Ok(());
    }
    let is_non_blocking = match self.get_socket_timeout(vm) {
        Ok(timeout) => timeout.map(|t| t.is_zero()).unwrap_or(false),
        Err(e) => {
            // Nothing was sent yet: keep the whole buffer for a later flush
            self.pending_tls_output.lock().extend_from_slice(&buf);
            return Err(e);
        }
    };

    let mut sent_total = 0;
    // The loop yields the final result so the save-on-error step below runs
    // for every failure, whatever its source
    let outcome: PyResult<()> = loop {
        if sent_total >= buf.len() {
            break Ok(());
        }
        let timed_out = match self.sock_wait_for_io_impl(SelectKind::Write, vm) {
            Ok(flag) => flag,
            Err(e) => break Err(e),
        };
        if timed_out {
            break Err(
                timeout_error_msg(vm, "The write operation timed out".to_string()).upcast(),
            );
        }
        match self.sock_send(&buf[sent_total..], vm) {
            Ok(result) => {
                let sent: usize = match result.try_to_value::<isize>(vm) {
                    Ok(count) => count.try_into().unwrap_or(0),
                    Err(e) => break Err(e),
                };
                if sent == 0 {
                    if is_non_blocking {
                        break Err(create_ssl_want_write_error(vm).upcast());
                    }
                    continue;
                }
                // Clamp so an implausible count can never slice past the buffer
                sent_total = sent_total.saturating_add(sent).min(buf.len());
            }
            Err(e) => {
                if is_blocking_io_error(&e, vm) {
                    if is_non_blocking {
                        break Err(create_ssl_want_write_error(vm).upcast());
                    }
                    continue;
                }
                break Err(e);
            }
        }
    };

    if outcome.is_err() {
        // Save unsent data to pending buffer
        self.pending_tls_output
            .lock()
            .extend_from_slice(&buf[sent_total..]);
    }
    outcome
}
/// Flush all pending TLS output data, respecting socket timeout
/// Used during handshake completion and shutdown() to ensure all data is sent
///
/// A non-blocking socket is handed to `flush_pending_tls_output`. Otherwise
/// this loops (wait for writability, send, drop what was sent) until the
/// pending buffer is empty, and raises a timeout error if a wait times out.
pub(crate) fn blocking_flush_all_pending(&self, vm: &VirtualMachine) -> PyResult<()> {
    // Get socket timeout to respect during flush
    let timeout = self.get_socket_timeout(vm)?;
    if timeout.is_some_and(|t| t.is_zero()) {
        return self.flush_pending_tls_output(vm, None);
    }

    loop {
        // Work on a snapshot; the buffer lock is not held while waiting or sending
        let snapshot = self.pending_tls_output.lock().clone();
        if snapshot.is_empty() {
            return Ok(());
        }

        // Wait for socket to be writable, respecting socket timeout
        let py_socket: PyRef<PySocket> = self.sock.clone().try_into_value(vm)?;
        let socket = py_socket
            .sock()
            .map_err(|e| vm.new_os_error(format!("Failed to get socket: {e}")))?;
        let timed_out = sock_select(&socket, SelectKind::Write, timeout)
            .map_err(|e| vm.new_os_error(format!("select failed: {e}")))?;
        if timed_out {
            return Err(
                timeout_error_msg(vm, "The write operation timed out".to_string()).upcast(),
            );
        }

        // Try to send pending data (use raw to avoid recursion)
        let sent: usize = match self.sock_send(&snapshot, vm) {
            Ok(result) => result.try_to_value::<isize>(vm)?.try_into().unwrap_or(0),
            // Would-block: go around again and wait for writability
            Err(e) if is_blocking_io_error(&e, vm) => continue,
            Err(e) => return Err(e),
        };
        // If sent == 0, the loop retries with sock_select
        if sent > 0 {
            self.pending_tls_output.lock().drain(..sent);
        }
    }
}
// Fixed textual representation for every instance
#[pymethod]
fn __repr__(&self) -> String {
    String::from("<SSLSocket>")
}
// Helper function to convert Python PROTO_* constants to rustls versions
fn get_rustls_versions(
minimum: i32,
maximum: i32,
options: i32,
) -> &'static [&'static rustls::SupportedProtocolVersion] {
// Rustls only supports TLS 1.2 and 1.3
// PROTO_TLSv1_2 = 0x0303, PROTO_TLSv1_3 = 0x0304
// PROTO_MINIMUM_SUPPORTED = -2, PROTO_MAXIMUM_SUPPORTED = -1
// If minimum and maximum are 0, use default (both TLS 1.2 and 1.3)
// Static arrays for single-version configurations
static TLS12_ONLY: &[&rustls::SupportedProtocolVersion] = &[&TLS12];
static TLS13_ONLY: &[&rustls::SupportedProtocolVersion] = &[&TLS13];
// Normalize special values: -2 (MINIMUM_SUPPORTED) → TLS 1.2, -1 (MAXIMUM_SUPPORTED) → TLS 1.3
let min = if minimum == -2 {
PROTO_TLSv1_2
} else {
minimum
};
let max = if maximum == -1 {
PROTO_TLSv1_3
} else {
maximum
};
// Check if versions are disabled by options
let tls12_disabled = (options & OP_NO_TLSv1_2) != 0;
let tls13_disabled = (options & OP_NO_TLSv1_3) != 0;
let want_tls12 = (min == 0 || min <= PROTO_TLSv1_2)
&& (max == 0 || max >= PROTO_TLSv1_2)
&& !tls12_disabled;
let want_tls13 = (min == 0 || min <= PROTO_TLSv1_3)
&& (max == 0 || max >= PROTO_TLSv1_3)
&& !tls13_disabled;
match (want_tls12, want_tls13) {
(true, true) => rustls::DEFAULT_VERSIONS, // Both TLS 1.2 and 1.3
(true, false) => TLS12_ONLY, // Only TLS 1.2
(false, true) => TLS13_ONLY, // Only TLS 1.3
(false, false) => rustls::DEFAULT_VERSIONS, // Fallback to default
}
}
/// Helper: compute the rustls version list from the context's minimum version,
/// maximum version and options
fn prepare_tls_versions(&self) -> &'static [&'static rustls::SupportedProtocolVersion] {
    let context = self.context.read();
    let minimum = *context.minimum_version.read();
    let maximum = *context.maximum_version.read();
    let option_bits = *context.options.read();
    Self::get_rustls_versions(minimum, maximum, option_bits)
}
/// Helper: Prepare KX groups (ECDH curve) from context settings
///
/// Returns `Ok(None)` when no curve is configured. Raises `ValueError` when
/// the configured curve name cannot be mapped to key-exchange groups.
fn prepare_kx_groups(
    &self,
    vm: &VirtualMachine,
) -> PyResult<Option<Vec<&'static dyn SupportedKxGroup>>> {
    // Copy the setting out; both guards are released at the end of this statement
    let Some(curve_name) = self.context.read().ecdh_curve.read().clone() else {
        return Ok(None);
    };
    curve_name_to_kx_group(&curve_name)
        .map(Some)
        .map_err(|e| vm.new_value_error(format!("Failed to set ECDH curve: {e}")))
}
/// Helper: Prepare all common protocol settings (versions, KX groups, ciphers, ALPN)
///
/// Raises `ValueError` (from `prepare_kx_groups`) when the configured ECDH
/// curve is invalid.
fn prepare_protocol_settings(&self, vm: &VirtualMachine) -> PyResult<ProtocolSettings> {
    // Both helpers take their own read lock on `self.context`, so call them
    // BEFORE locking it here; this function then never re-acquires a read
    // lock it already holds.
    // NOTE(review): a re-entrant read can block behind a queued writer if
    // PyRwLock is backed by parking_lot - confirm against the lock implementation.
    let versions = self.prepare_tls_versions();
    let kx_groups = self.prepare_kx_groups(vm)?;

    let ctx = self.context.read();
    let cipher_suites = ctx.selected_ciphers.read().clone();
    let alpn_protocols = ctx.alpn_protocols.read().clone();
    drop(ctx);

    Ok(ProtocolSettings {
        versions,
        kx_groups,
        cipher_suites,
        alpn_protocols,
    })
}
/// Initialize server-side TLS connection with configuration
///
/// This method handles all server-side setup including:
/// - Certificate and key validation
/// - Client authentication configuration
/// - SNI (Server Name Indication) setup
/// - ALPN protocol negotiation
/// - Session resumption configuration
///
/// Returns the configured ServerConnection.
fn initialize_server_connection(
&self,
conn_guard: &mut Option<TlsConnection>,
vm: &VirtualMachine,
) -> PyResult<()> {
let ctx = self.context.read();
let cert_keys = ctx.cert_keys.read();
if cert_keys.is_empty() {
return Err(vm.new_value_error(
"Server-side connection requires certificate and key (use load_cert_chain)",
));
}
// Clone cert_keys for use in config
// PrivateKeyDer doesn't implement Clone, use clone_key()
let cert_keys_clone: Vec<CertKeyPair> = cert_keys
.iter()
.map(|(ck, pk)| (ck.clone(), pk.clone_key()))
.collect();
drop(cert_keys);
// Prepare common protocol settings (TLS versions, ECDH curve, cipher suites, ALPN)
let protocol_settings = self.prepare_protocol_settings(vm)?;
let min_ver = *ctx.minimum_version.read();
// Check if client certificate verification is required
let verify_mode = *ctx.verify_mode.read();
let root_store = ctx.root_certs.read();
let pha_enabled = *ctx.post_handshake_auth.read();
// Check if TLS 1.3 is being used
let is_tls13 = min_ver >= PROTO_TLSv1_3;
// For TLS 1.3: always use deferred validation for client certificates
// For TLS 1.2: use immediate validation during handshake
let use_deferred_validation = is_tls13
&& !pha_enabled
&& (verify_mode == CERT_REQUIRED || verify_mode == CERT_OPTIONAL);
// For TLS 1.3 + PHA: if PHA is enabled, don't request cert in initial handshake
// The certificate will be requested later via verify_client_post_handshake()
let request_initial_cert = if pha_enabled {
// PHA enabled: don't request cert initially (will use PHA later)
false
} else if verify_mode == CERT_REQUIRED || verify_mode == CERT_OPTIONAL {
// PHA not enabled or TLS 1.2: request cert in initial handshake
true
} else {
// CERT_NONE
false
};
// Check if SNI callback is set
let sni_callback = ctx.sni_callback.read().clone();
let use_sni_resolver = sni_callback.is_some();
// Create SNI state if needed (to be stored in PySSLSocket later)
// For SNI, use the first cert_key pair as the initial certificate
let sni_state: Option<Arc<ParkingMutex<SniCertName>>> = if use_sni_resolver {
// Use first cert_key as initial certificate for SNI
// Extract CertifiedKey from tuple
let (first_cert_key, _) = &cert_keys_clone[0];
let first_cert_key = first_cert_key.clone();
// Check if we already have existing SNI state (from previous connection)
let existing_sni_state = self.sni_state.read().clone();
if let Some(sni_state_arc) = existing_sni_state {
// Reuse existing Arc and update its contents
// This is crucial: rustls SniCertResolver holds references to this Arc
let mut state = sni_state_arc.lock();
state.0 = first_cert_key;
state.1 = None; // Reset SNI name for new connection
drop(state);
// Return the existing Arc (not a new one!)
Some(sni_state_arc)
} else {
// First connection: create new SNI state
Some(Arc::new(ParkingMutex::new((first_cert_key, None))))
}
} else {
None
};
// Determine which cert resolver to use
// Priority: SNI > Multi-cert/Single-cert via MultiCertResolver
let cert_resolver: Option<Arc<dyn ResolvesServerCert>> = if use_sni_resolver {
// SNI takes precedence - use first cert_key for initial setup
sni_state.as_ref().map(|sni_state_arc| {
Arc::new(SniCertResolver {
sni_state: sni_state_arc.clone(),
}) as Arc<dyn ResolvesServerCert>
})
} else {
// Use MultiCertResolver for all cases (single or multiple certs)
// Extract CertifiedKey from tuples for MultiCertResolver
let cert_keys_only: Vec<Arc<CertifiedKey>> =
cert_keys_clone.iter().map(|(ck, _)| ck.clone()).collect();
Some(Arc::new(MultiCertResolver::new(cert_keys_only)))
};
// Extract cert_chain and private_key from first cert_key
//
// Note: Since we always use cert_resolver now, these values won't actually be used
// by create_server_config. But we still need to provide them for the API signature.
let (first_cert_key, _) = &cert_keys_clone[0];
let certs_clone = first_cert_key.cert.clone();
// Provide a dummy key since cert_resolver will handle cert selection
let key_clone = PrivateKeyDer::Pkcs8(Vec::new().into());
// Get shared server session storage and ticketer from context
let server_session_storage = ctx.rustls_server_session_store.clone();
let server_ticketer = ctx.server_ticketer.clone();
// Build server config using compat helper
let config_options = ServerConfigOptions {
protocol_settings,
cert_chain: certs_clone,
private_key: key_clone,
root_store: if request_initial_cert {
Some(root_store.clone())
} else {
None
},
request_client_cert: request_initial_cert,
use_deferred_validation,
cert_resolver,
deferred_cert_error: if use_deferred_validation {
Some(self.deferred_cert_error.clone())
} else {
None
},
session_storage: Some(server_session_storage),
ticketer: Some(server_ticketer),
};
drop(root_store);
// Check if we have a cached ServerConfig
let cached_config_arc = ctx.server_config.read().clone();
drop(ctx);
let config_arc = if let Some(cached) = cached_config_arc {
// Don't use cache when SNI is enabled, because each connection needs
// a fresh SniCertResolver with the correct Arc references
if use_sni_resolver {
let config =
create_server_config(config_options).map_err(|e| vm.new_value_error(e))?;
Arc::new(config)
} else {
cached
}
} else {
let config =
create_server_config(config_options).map_err(|e| vm.new_value_error(e))?;
let config_arc = Arc::new(config);
// Cache the ServerConfig for future connections
let ctx = self.context.read();
*ctx.server_config.write() = Some(config_arc.clone());
drop(ctx);
config_arc
};
let conn = ServerConnection::new(config_arc).map_err(|e| {
vm.new_value_error(format!("Failed to create server connection: {e}"))
})?;
*conn_guard = Some(TlsConnection::Server(conn));
// If ClientHello buffer exists (from SNI callback), re-inject it
if let Some(ref hello_data) = *self.client_hello_buffer.lock()
&& let Some(TlsConnection::Server(ref mut server)) = *conn_guard
{
let mut cursor = std::io::Cursor::new(hello_data.as_slice());
let _ = server.read_tls(&mut cursor);
// Process the re-injected ClientHello
let _ = server.process_new_packets();
// DON'T clear buffer - keep it to prevent callback from being invoked again
// The buffer being non-empty signals that SNI callback was already processed
}
// Store SNI state if we're using SNI resolver
if let Some(sni_state_arc) = sni_state {
*self.sni_state.write() = Some(sni_state_arc);
}
Ok(())
}
// Drive the TLS handshake for this socket.
//
// Returns Ok(()) immediately when `handshake_done` is already set. Otherwise the
// rustls connection is created on first use (after applying any `pending_context`),
// `ssl_do_handshake` is run against it, and `complete_handshake` is called on success.
//
// Server side only: `SslError::SniCallbackRestart` causes the SNI callback to be
// invoked, the current connection to be discarded, and this method to be re-entered
// so the connection is rebuilt.
//
// Errors: failures while building the client config, parsing the server hostname or
// creating the connection are raised as ValueError; remaining handshake errors are
// converted with `into_py_err`.
#[pymethod]
fn do_handshake(&self, vm: &VirtualMachine) -> PyResult<()> {
// Check if handshake already done
if *self.handshake_done.lock() {
return Ok(());
}
let mut conn_guard = self.connection.lock();
// Initialize connection if not already done
if conn_guard.is_none() {
// Check for pending context change (from SNI callback)
if let Some(new_ctx) = self.pending_context.write().take() {
*self.context.write() = new_ctx;
}
if self.server_side {
// Server-side connection - delegate to helper method
self.initialize_server_connection(&mut conn_guard, vm)?;
} else {
// Client-side connection
let ctx = self.context.read();
// Prepare common protocol settings (TLS versions, ECDH curve, cipher suites, ALPN)
let protocol_settings = self.prepare_protocol_settings(vm)?;
// Clone values we need before building config
let verify_mode = *ctx.verify_mode.read();
let root_store_clone = ctx.root_certs.read().clone();
let ca_certs_der_clone = ctx.ca_certs_der.read().clone();
// For client mTLS: extract cert_chain and private_key from first cert_key (if any)
// Now we store both CertifiedKey and PrivateKeyDer as tuple
let cert_keys_guard = ctx.cert_keys.read();
let (cert_chain_clone, private_key_opt) = if !cert_keys_guard.is_empty() {
let (first_cert_key, private_key) = &cert_keys_guard[0];
let certs = first_cert_key.cert.clone();
(certs, Some(private_key.clone_key()))
} else {
(Vec::new(), None)
};
drop(cert_keys_guard);
let check_hostname = *ctx.check_hostname.read();
let verify_flags = *ctx.verify_flags.read();
// Get session store before dropping ctx
let session_store = ctx.rustls_session_store.clone();
// Get CRLs for revocation checking
let crls_clone = ctx.crls.read().clone();
// Drop ctx early to avoid borrow conflicts
drop(ctx);
// Build client config using compat helper
// The root store is only supplied, and the server certificate only verified,
// when verify_mode is not CERT_NONE.
let config_options = ClientConfigOptions {
protocol_settings,
root_store: if verify_mode != CERT_NONE {
Some(root_store_clone)
} else {
None
},
ca_certs_der: ca_certs_der_clone,
cert_chain: if !cert_chain_clone.is_empty() {
Some(cert_chain_clone)
} else {
None
},
private_key: private_key_opt,
verify_server_cert: verify_mode != CERT_NONE,
check_hostname,
verify_flags,
session_store: Some(session_store),
crls: crls_clone,
};
let config =
create_client_config(config_options).map_err(|e| vm.new_value_error(e))?;
// Parse server name for SNI
// Convert to ServerName
use rustls::pki_types::ServerName;
let hostname_opt = self.server_hostname.read().clone();
let server_name = if let Some(ref hostname) = hostname_opt {
// Use the provided hostname for SNI
ServerName::try_from(hostname.clone()).map_err(|e| {
vm.new_value_error(format!("Invalid server hostname: {e:?}"))
})?
} else {
// When server_hostname=None, use an IP address to suppress SNI
// no hostname = no SNI extension
ServerName::IpAddress(
core::net::IpAddr::V4(core::net::Ipv4Addr::new(127, 0, 0, 1)).into(),
)
};
let conn = ClientConnection::new(Arc::new(config), server_name.clone())
.map_err(|e| {
vm.new_value_error(format!("Failed to create client connection: {e}"))
})?;
*conn_guard = Some(TlsConnection::Client(conn));
}
}
// Perform the actual handshake by exchanging data with the socket/BIO
// Both branches above either stored a connection or returned early via `?`.
let conn = conn_guard.as_mut().expect("unreachable");
let is_client = matches!(conn, TlsConnection::Client(_));
let handshake_result = ssl_do_handshake(conn, self, vm);
// Release the connection lock before the follow-up work below; the SNI restart
// branch locks `self.connection` again.
drop(conn_guard);
if is_client {
// CLIENT is simple - no SNI callback handling needed
handshake_result.map_err(|e| e.into_py_err(vm))?;
self.complete_handshake(vm)?;
Ok(())
} else {
// Use OpenSSL-compatible handshake for server
// Handle SNI callback restart
match handshake_result {
Ok(()) => {
// Handshake completed successfully
self.complete_handshake(vm)?;
Ok(())
}
Err(SslError::SniCallbackRestart) => {
// SNI detected - need to call callback and recreate connection
// Get the SNI name that was extracted (may be None if client didn't send SNI)
let sni_name = self.get_extracted_sni_name();
// Now safe to call Python callback (no locks held)
self.invoke_sni_callback(sni_name.as_deref(), vm)?;
// Clear connection to trigger recreation
*self.connection.lock() = None;
// Recursively call do_handshake to recreate with new context
self.do_handshake(vm)
}
Err(e) => {
// Other errors - convert to Python exception
Err(e.into_py_err(vm))
}
}
}
}
// Read decrypted application data from the TLS connection.
//
// Arguments:
// - `len`: maximum number of bytes to read; defaults to PEM_BUFSIZE. Negative values
//   raise ValueError("negative read length"), except -1 together with `buffer`,
//   which means "use the buffer's length".
// - `buffer`: optional writable buffer. When present, `len` is clamped to its size
//   (a non-positive `len` is replaced by the buffer length), the data is copied into
//   it and the number of bytes is returned as an int. When absent, a bytes object
//   is returned.
//
// The handshake is completed first if it has not been done yet. Reading after the
// shutdown state reached `Completed` raises an SSLError. Plaintext still buffered in
// the connection is returned before an EOF / close_notify condition is reported.
#[pymethod]
fn read(
&self,
len: OptionalArg<isize>,
buffer: OptionalArg<ArgMemoryBuffer>,
vm: &VirtualMachine,
) -> PyResult {
// Convert len to usize, defaulting to 1024 if not provided
// -1 means read all available data (treat as large buffer size)
let len_val = len.unwrap_or(PEM_BUFSIZE as isize);
let mut len = if len_val == -1 {
// -1 is only valid when a buffer is provided
match &buffer {
OptionalArg::Present(buf_arg) => buf_arg.len(),
OptionalArg::Missing => {
return Err(vm.new_value_error("negative read length"));
}
}
} else if len_val < 0 {
return Err(vm.new_value_error("negative read length"));
} else {
len_val as usize
};
// if buffer is provided, limit len to buffer size
if let OptionalArg::Present(buf_arg) = &buffer {
let buf_len = buf_arg.len();
if len_val <= 0 || len > buf_len {
len = buf_len;
}
}
// return empty bytes immediately for len=0
// (with a buffer the result is the int 0 instead of empty bytes)
if len == 0 {
return match buffer {
OptionalArg::Present(_) => Ok(vm.ctx.new_int(0).into()),
OptionalArg::Missing => Ok(vm.ctx.new_bytes(vec![]).into()),
};
}
// Ensure handshake is done - if not, complete it first
// This matches OpenSSL behavior where SSL_read() auto-completes handshake
if !*self.handshake_done.lock() {
self.do_handshake(vm)?;
}
// Check if connection has been shut down
// Only block after shutdown is COMPLETED, not during shutdown process
let shutdown_state = *self.shutdown_state.lock();
if shutdown_state == ShutdownState::Completed {
return Err(vm
.new_os_subtype_error(
PySSLError::class(&vm.ctx).to_owned(),
None,
"cannot read after shutdown",
)
.upcast());
}
// Helper function to handle return value based on buffer presence
let return_data = |data: Vec<u8>,
buffer_arg: &OptionalArg<ArgMemoryBuffer>,
vm: &VirtualMachine|
-> PyResult<PyObjectRef> {
match buffer_arg {
OptionalArg::Present(buf_arg) => {
// Write into buffer and return number of bytes written
// NOTE(review): the returned count is `data.len()`, while the copy is
// capped at the buffer length; these agree as long as `len` was clamped
// to the buffer size above.
let n = data.len();
if n > 0 {
let mut buf = buf_arg.borrow_buf_mut();
let buf_slice = &mut *buf;
let copy_len = n.min(buf_slice.len());
buf_slice[..copy_len].copy_from_slice(&data[..copy_len]);
}
Ok(vm.ctx.new_int(n).into())
}
OptionalArg::Missing => {
// Return bytes object
Ok(vm.ctx.new_bytes(data).into())
}
}
};
// Use compat layer for unified read logic with proper EOF handling
// This matches SSL_read_ex() approach
let mut buf = vec![0u8; len];
// The connection lock is held only for the duration of the ssl_read call.
let read_result = {
let mut conn_guard = self.connection.lock();
let conn = conn_guard
.as_mut()
.ok_or_else(|| vm.new_value_error("Connection not established"))?;
crate::ssl::compat::ssl_read(conn, &mut buf, self, vm)
};
match read_result {
Ok(n) => {
// Check for deferred certificate verification errors (TLS 1.3)
// Must be checked AFTER ssl_read, as the error is set during I/O
self.check_deferred_cert_error(vm)?;
buf.truncate(n);
return_data(buf, &buffer, vm)
}
Err(crate::ssl::compat::SslError::Eof) => {
// If plaintext is still buffered, return it before EOF.
let pending = {
let mut conn_guard = self.connection.lock();
let conn = match conn_guard.as_mut() {
Some(conn) => conn,
None => return Err(create_ssl_eof_error(vm).upcast()),
};
use std::io::BufRead;
let mut reader = conn.reader();
reader.fill_buf().map(|buf| buf.len()).unwrap_or(0)
};
if pending > 0 {
// Retry with a buffer sized to the buffered plaintext (capped at `len`);
// a failed retry falls through to the EOF error below.
let mut buf = vec![0u8; pending.min(len)];
let read_retry = {
let mut conn_guard = self.connection.lock();
let conn = conn_guard
.as_mut()
.ok_or_else(|| vm.new_value_error("Connection not established"))?;
crate::ssl::compat::ssl_read(conn, &mut buf, self, vm)
};
if let Ok(n) = read_retry {
buf.truncate(n);
return return_data(buf, &buffer, vm);
}
}
// EOF occurred in violation of protocol (unexpected closure)
Err(create_ssl_eof_error(vm).upcast())
}
Err(crate::ssl::compat::SslError::ZeroReturn) => {
// If plaintext is still buffered, return it before clean EOF.
let pending = {
let mut conn_guard = self.connection.lock();
let conn = match conn_guard.as_mut() {
Some(conn) => conn,
None => return Err(create_ssl_zero_return_error(vm).upcast()),
};
use std::io::BufRead;
let mut reader = conn.reader();
reader.fill_buf().map(|buf| buf.len()).unwrap_or(0)
};
if pending > 0 {
// Same retry as in the Eof branch; a failed retry falls through to the
// close_notify handling below.
let mut buf = vec![0u8; pending.min(len)];
let read_retry = {
let mut conn_guard = self.connection.lock();
let conn = conn_guard
.as_mut()
.ok_or_else(|| vm.new_value_error("Connection not established"))?;
crate::ssl::compat::ssl_read(conn, &mut buf, self, vm)
};
if let Ok(n) = read_retry {
buf.truncate(n);
return return_data(buf, &buffer, vm);
}
}
// Clean closure with close_notify
// CPython behavior depends on whether we've sent our close_notify:
// - If we've already sent close_notify (unwrap was called): raise SSLZeroReturnError
// - If we haven't sent close_notify yet: return empty bytes
let our_shutdown_state = *self.shutdown_state.lock();
if our_shutdown_state == ShutdownState::SentCloseNotify
|| our_shutdown_state == ShutdownState::Completed
{
// We already sent close_notify, now receiving peer's → SSLZeroReturnError
Err(create_ssl_zero_return_error(vm).upcast())
} else {
// We haven't sent close_notify yet → return empty bytes
return_data(vec![], &buffer, vm)
}
}
Err(crate::ssl::compat::SslError::WantRead) => {
// Non-blocking mode: would block
Err(create_ssl_want_read_error(vm).upcast())
}
Err(crate::ssl::compat::SslError::WantWrite) => {
// Non-blocking mode: would block on write
Err(create_ssl_want_write_error(vm).upcast())
}
Err(crate::ssl::compat::SslError::Timeout(msg)) => {
Err(timeout_error_msg(vm, msg).upcast())
}
Err(crate::ssl::compat::SslError::Py(e)) => {
// Python exception - pass through
Err(e)
}
Err(e) => {
// Other SSL errors
Err(e.into_py_err(vm))
}
}
}
// Report how many already-decrypted bytes can be read without further I/O.
// Yields 0 when no connection exists yet or when the reader reports an error.
#[pymethod]
fn pending(&self) -> PyResult<usize> {
    use std::io::BufRead;

    let mut guard = self.connection.lock();
    let Some(tls) = guard.as_mut() else {
        // Nothing can be buffered before a connection is created.
        return Ok(0);
    };

    // fill_buf() exposes the buffered plaintext without consuming it, so its
    // length is the pending byte count. Errors (e.g. WouldBlock) count as empty.
    let mut plaintext = tls.reader();
    let available = plaintext
        .fill_buf()
        .map(|chunk| chunk.len())
        .unwrap_or(0);
    Ok(available)
}
// Encrypt and send `data`, returning the number of bytes accepted.
//
// An empty payload returns 0 without touching the connection. The handshake is
// completed first when still pending, and writing after shutdown has reached
// `Completed` raises an SSLError.
#[pymethod]
fn write(&self, data: ArgBytesLike, vm: &VirtualMachine) -> PyResult<usize> {
    use crate::ssl::compat::SslError as CompatError;

    let payload = data.borrow_buf();
    if payload.len() == 0 {
        return Ok(0);
    }

    // SSL_write-style behaviour: finish the handshake implicitly.
    if !*self.handshake_done.lock() {
        self.do_handshake(vm)?;
    }

    // Writes are refused only once shutdown has fully completed, not while it
    // is still in progress.
    if *self.shutdown_state.lock() == ShutdownState::Completed {
        let closed = vm.new_os_subtype_error(
            PySSLError::class(&vm.ctx).to_owned(),
            None,
            "cannot write after shutdown",
        );
        return Err(closed.upcast());
    }

    // Hold the connection lock only while ssl_write runs.
    let outcome = {
        let mut guard = self.connection.lock();
        match guard.as_mut() {
            Some(tls) => crate::ssl::compat::ssl_write(tls, payload.as_ref(), self, vm),
            None => return Err(vm.new_value_error("Connection not established")),
        }
    };

    let written = match outcome {
        Ok(count) => count,
        Err(CompatError::WantRead) => return Err(create_ssl_want_read_error(vm).upcast()),
        Err(CompatError::WantWrite) => return Err(create_ssl_want_write_error(vm).upcast()),
        Err(CompatError::Timeout(msg)) => return Err(timeout_error_msg(vm, msg).upcast()),
        Err(other) => return Err(other.into_py_err(vm)),
    };
    // Surface any certificate error that was recorded during the I/O above.
    self.check_deferred_cert_error(vm)?;
    Ok(written)
}
// Return the peer's leaf certificate.
//
// - ValueError if the handshake has not completed or no connection exists.
// - None when the connection reports no peer certificates.
// - `binary_form=True`: the DER bytes.
// - Otherwise: an empty dict under CERT_NONE, else the parsed certificate as a dict.
#[pymethod]
fn getpeercert(
    &self,
    args: GetCertArgs,
    vm: &VirtualMachine,
) -> PyResult<Option<PyObjectRef>> {
    let want_der = args.binary_form.unwrap_or(false);

    if !*self.handshake_done.lock() {
        return Err(vm.new_value_error("handshake not done yet"));
    }

    // Copy the leaf certificate out so the connection lock is released quickly.
    let leaf_der = {
        let guard = self.connection.lock();
        let Some(tls) = guard.as_ref() else {
            return Err(vm.new_value_error("No TLS connection established"));
        };
        let Some(chain) = tls.peer_certificates() else {
            return Ok(None);
        };
        match chain.first() {
            Some(leaf) => leaf.as_ref().to_vec(),
            None => return Err(vm.new_value_error("No peer certificate available")),
        }
    };

    if want_der {
        return Ok(Some(vm.ctx.new_bytes(leaf_der).into()));
    }

    // Dict form: without verification only an empty dict is exposed.
    let verify_mode = *self.context.read().verify_mode.read();
    if verify_mode == CERT_NONE {
        return Ok(Some(vm.ctx.new_dict().into()));
    }

    // Parsing happens outside the connection lock.
    match x509_parser::parse_x509_certificate(&leaf_der) {
        Ok((_, parsed)) => cert::cert_to_dict(vm, &parsed).map(Some),
        Err(e) => Err(vm.new_value_error(format!("Failed to parse certificate: {e}"))),
    }
}
// Describe the negotiated cipher as (name, protocol version, secret bits).
// Returns None when there is no connection or no cipher suite has been negotiated.
#[pymethod]
fn cipher(&self) -> Option<(String, String, i32)> {
    // The lock is only needed to fetch the suite; the temporary guard is released
    // at the end of this statement.
    let negotiated = self.connection.lock().as_ref()?.negotiated_cipher_suite()?;

    let details = extract_cipher_info(&negotiated);
    // Only three fields are reported here; the longer description belongs to
    // get_ciphers() output.
    let protocol = details.protocol.to_string();
    Some((details.name, protocol, details.bits))
}
// Name of the negotiated TLS protocol version, derived from the cipher suite.
// Returns None when there is no connection or no suite has been negotiated.
#[pymethod]
fn version(&self) -> Option<String> {
    let negotiated = {
        let guard = self.connection.lock();
        guard.as_ref()?.negotiated_cipher_suite()?
    };

    // Map the protocol version to its Python-visible label outside the lock.
    let label = match negotiated.version().version {
        rustls::ProtocolVersion::TLSv1_3 => "TLSv1.3",
        rustls::ProtocolVersion::TLSv1_2 => "TLSv1.2",
        _ => "Unknown",
    };
    Some(String::from(label))
}
// ALPN protocol agreed on during the handshake, as a string.
// Returns None without a connection, without a negotiated protocol, for the
// single-NUL placeholder protocol, or when the bytes are not valid UTF-8.
#[pymethod]
fn selected_alpn_protocol(&self) -> Option<String> {
    let guard = self.connection.lock();
    let negotiated = guard.as_ref()?.alpn_protocol()?;

    // A lone NUL byte is the fallback marker meaning "no real ALPN match".
    let is_placeholder = negotiated.is_empty() || negotiated == [0u8];
    if is_placeholder {
        None
    } else {
        String::from_utf8(negotiated.to_vec()).ok()
    }
}
// NPN is the deprecated predecessor of ALPN (RFC 7301) and is not offered by
// rustls, so no protocol is ever reported.
#[pymethod]
fn selected_npn_protocol(&self) -> Option<String> {
    None
}
// Current owner object, or None when no owner has been assigned.
#[pygetset]
fn owner(&self) -> Option<PyObjectRef> {
    let slot = self.owner.read();
    (*slot).clone()
}
// Replace the stored owner object.
#[pygetset(setter)]
fn set_owner(&self, owner: PyObjectRef, _vm: &VirtualMachine) -> PyResult<()> {
    let mut slot = self.owner.write();
    *slot = Some(owner);
    Ok(())
}
// Whether this socket plays the server role of the TLS connection.
#[pygetset]
fn server_side(&self) -> bool {
    let is_server = self.server_side;
    is_server
}
// The SSLContext currently associated with this socket.
#[pygetset]
fn context(&self) -> PyRef<PySSLContext> {
    let active = self.context.read();
    (*active).clone()
}
// Switch this socket to another SSLContext.
// The new context takes effect immediately (as SSL_set_SSL_CTX permits, even after
// the handshake), and any queued pending context is discarded because the change
// has now been applied.
#[pygetset(setter)]
fn set_context(&self, value: PyRef<PySSLContext>, _vm: &VirtualMachine) -> PyResult<()> {
    {
        let mut active = self.context.write();
        *active = value;
    }
    {
        let mut queued = self.pending_context.write();
        *queued = None;
    }
    Ok(())
}
// Hostname configured for this socket, or None when none was given.
#[pygetset]
fn server_hostname(&self) -> Option<String> {
    let name = self.server_hostname.read();
    (*name).clone()
}
// Change the hostname used for this socket.
// Raises ValueError once the handshake has completed; a non-None value must pass
// `validate_hostname` before it is stored. None clears the hostname.
#[pygetset(setter)]
fn set_server_hostname(
    &self,
    value: Option<PyStrRef>,
    vm: &VirtualMachine,
) -> PyResult<()> {
    if *self.handshake_done.lock() {
        return Err(
            vm.new_value_error("Cannot set server_hostname on socket after handshake")
        );
    }

    // Validate first, then convert to an owned String for storage.
    let hostname = match value {
        Some(name) => {
            validate_hostname(name.as_str(), vm)?;
            Some(name.as_str().to_string())
        }
        None => None,
    };

    *self.server_hostname.write() = hostname;
    Ok(())
}
// The session object stored on this socket, or Python None when there is none.
#[pygetset]
fn session(&self, vm: &VirtualMachine) -> PyResult<PyObjectRef> {
    // Clone out of the lock first so the guard is not held while building None.
    let stored = self.session.read().clone();
    Ok(match stored {
        Some(session) => session,
        None => vm.ctx.none(),
    })
}
// Store a session to be offered during the (not yet performed) client handshake.
//
// `value` must be an SSLSession or None; None clears the stored session.
//
// Raises:
// - TypeError("Value is not a SSLSession.") for any other object.
// - ValueError on server-side sockets.
// - ValueError once the handshake has completed.
#[pygetset(setter)]
fn set_session(&self, value: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> {
    // Compare against the None singleton. The previous check compared the value's
    // identity with the NoneType *type object*, so a real None was rejected with
    // TypeError while passing the NoneType class silently cleared the session.
    let clearing = vm.is_none(&value);

    // Anything that is not None must be an SSLSession instance.
    if !clearing && value.downcast_ref::<PySSLSession>().is_none() {
        return Err(vm.new_type_error("Value is not a SSLSession."));
    }

    // Sessions can only be supplied by clients.
    if self.server_side {
        return Err(vm.new_value_error("Cannot set session for server-side SSLSocket"));
    }

    // A session is only meaningful before the handshake runs.
    if *self.handshake_done.lock() {
        return Err(vm.new_value_error("Cannot set session after handshake."));
    }

    // Keep the session for potential use during the handshake.
    *self.session.write() = if clearing { None } else { Some(value) };
    Ok(())
}
// Whether the tracked state says this connection resumed an earlier session.
#[pygetset]
fn session_reused(&self) -> bool {
    let reused = self.session_was_reused.lock();
    *reused
}
// TLS compression is not available with rustls, so there is never an algorithm
// name to report.
#[pymethod]
fn compression(&self) -> Option<&'static str> {
    None
}
// Return the certificates the peer presented, as a list of Certificate objects.
// Raises ValueError when no connection exists; returns None when the connection
// reports no peer certificates.
#[pymethod]
fn get_unverified_chain(&self, vm: &VirtualMachine) -> PyResult<Option<PyListRef>> {
    let guard = self.connection.lock();
    let Some(tls) = guard.as_ref() else {
        return Err(vm.new_value_error("Handshake not completed"));
    };
    let Some(presented) = tls.peer_certificates() else {
        return Ok(None);
    };

    // Wrap a copy of each DER blob in its own Python Certificate object.
    let mut wrapped: Vec<PyObjectRef> = Vec::new();
    for der in presented.iter() {
        let certificate = PySSLCertificate {
            der_bytes: der.as_ref().to_vec(),
        };
        wrapped.push(certificate.into_ref(&vm.ctx).into());
    }
    Ok(Some(vm.ctx.new_list(wrapped)))
}
#[pymethod]
fn get_verified_chain(&self, vm: &VirtualMachine) -> PyResult<Option<PyListRef>> {
// Get peer certificates (what peer sent during handshake)
let conn_guard = self.connection.lock();
let Some(ref conn) = *conn_guard else {
return Ok(None);
};
let peer_certs = conn.peer_certificates();
let Some(peer_certs_slice) = peer_certs else {
return Ok(None);
};
// Build the verified chain using cert module
let ctx_guard = self.context.read();
let ca_certs_der = ctx_guard.ca_certs_der.read();
let chain_der = cert::build_verified_chain(peer_certs_slice, &ca_certs_der);
// Convert DER chain to Python list of Certificate objects
let cert_list: Vec<PyObjectRef> = chain_der
.into_iter()
.map(|der_bytes| PySSLCertificate { der_bytes }.into_ref(&vm.ctx).into())
.collect();
Ok(Some(vm.ctx.new_list(cert_list)))
}
// Perform the TLS shutdown (close_notify exchange).
//
// Returns the underlying socket object in socket mode and None in BIO mode.
// Calling it again after the state reached `Completed` returns immediately.
//
// Phases:
// 1. If shutdown has not started: flush all pending TLS output, queue our
//    close_notify, write it out and set the state to `SentCloseNotify`.
// 2. Look for the peer's close_notify:
//    - BIO mode: one non-blocking read attempt.
//    - Socket with timeout 0: best-effort flush, mark `Completed`, drop the
//      connection and return without waiting.
//    - Blocking socket / positive timeout: wait (up to the deadline) until the
//      peer closes, then mark `Completed` and drop the connection.
// 3. If the peer's close_notify still has not been seen, raise SSLWantReadError;
//    otherwise mark `Completed`.
//
// Raises ValueError when no connection exists, OSError when flushing TLS records
// fails, and TimeoutError when the timeout elapses while waiting for the peer.
#[pymethod]
fn shutdown(&self, vm: &VirtualMachine) -> PyResult<PyObjectRef> {
// Check current shutdown state
let current_state = *self.shutdown_state.lock();
// If already completed, return immediately
if current_state == ShutdownState::Completed {
if self.is_bio_mode() {
return Ok(vm.ctx.none());
}
return Ok(self.sock.clone());
}
// Get connection
let mut conn_guard = self.connection.lock();
let conn = conn_guard
.as_mut()
.ok_or_else(|| vm.new_value_error("Connection not established"))?;
let is_bio = self.is_bio_mode();
// Step 1: Send our close_notify if not already sent
if current_state == ShutdownState::NotStarted {
// First, flush ALL pending TLS data BEFORE sending close_notify
// This is CRITICAL - close_notify must come AFTER all application data
// Otherwise data loss occurs when peer receives close_notify first
// Step 1a: Flush any pending TLS records from rustls internal buffer
// This ensures all application data is converted to TLS records
while conn.wants_write() {
let mut buf = Vec::new();
conn.write_tls(&mut buf)
.map_err(|e| vm.new_os_error(format!("TLS write failed: {e}")))?;
if !buf.is_empty() {
self.send_tls_output(buf, vm)?;
}
}
// Step 1b: Flush pending_tls_output buffer to socket
if !is_bio {
// Socket mode: blocking flush to ensure data order
// Must complete before sending close_notify
self.blocking_flush_all_pending(vm)?;
} else {
// BIO mode: non-blocking flush (caller handles pending data)
let _ = self.flush_pending_tls_output(vm, None);
}
conn.send_close_notify();
// Write close_notify to outgoing buffer/BIO
self.write_pending_tls(conn, vm)?;
// Ensure close_notify and any pending TLS data are flushed
if !is_bio {
self.flush_pending_tls_output(vm, None)?;
}
// Update state
*self.shutdown_state.lock() = ShutdownState::SentCloseNotify;
}
// Step 2: Try to read and process peer's close_notify
// First check if we already have peer's close_notify
// This can happen if it was received during a previous read() call
let mut peer_closed = self.check_peer_closed(conn, vm)?;
// If peer hasn't closed yet, try to read from socket
if !peer_closed {
// Check socket timeout mode
// Encoding of `timeout_mode`:
//   Some(None)    - blocking socket (gettimeout() returned None)
//   Some(Some(t)) - timeout of `t` seconds (0.0 means non-blocking)
//   None          - BIO mode, or the timeout could not be obtained
let timeout_mode = if !is_bio {
// Get socket timeout
match self.sock.get_attr("gettimeout", vm) {
Ok(method) => match method.call((), vm) {
Ok(timeout) => {
if vm.is_none(&timeout) {
// timeout=None means blocking
Some(None)
} else if let Ok(t) = timeout.try_float(vm).map(|f| f.to_f64()) {
if t == 0.0 {
// timeout=0 means non-blocking
Some(Some(0.0))
} else {
// timeout>0 means timeout mode
Some(Some(t))
}
} else {
None
}
}
Err(_) => None,
},
Err(_) => None,
}
} else {
None // BIO mode
};
if is_bio {
// In BIO mode: non-blocking read attempt
if self.try_read_close_notify(conn, vm)? {
peer_closed = true;
}
} else if let Some(timeout) = timeout_mode {
match timeout {
Some(0.0) => {
// Non-blocking: return immediately after sending close_notify.
// Don't wait for peer's close_notify to avoid blocking.
drop(conn_guard);
// Best-effort flush; WouldBlock is expected in non-blocking mode.
// Other errors indicate close_notify may not have been sent,
// but we still complete shutdown to avoid inconsistent state.
let _ = self.flush_pending_tls_output(vm, None);
*self.shutdown_state.lock() = ShutdownState::Completed;
*self.connection.lock() = None;
return Ok(self.sock.clone());
}
_ => {
// Blocking or timeout mode: wait for peer's close_notify.
// This is proper TLS shutdown - we should receive peer's
// close_notify before closing the connection.
// The lock is released here and re-acquired per loop iteration below so
// it is not held while waiting on the socket.
drop(conn_guard);
// Flush our close_notify first
if timeout.is_none() {
self.blocking_flush_all_pending(vm)?;
} else {
self.flush_pending_tls_output(vm, None)?;
}
// Calculate deadline for timeout mode
let deadline = timeout.map(|t| {
std::time::Instant::now() + core::time::Duration::from_secs_f64(t)
});
// Wait for peer's close_notify
loop {
// Re-acquire connection lock for each iteration
let mut conn_guard = self.connection.lock();
let conn = match conn_guard.as_mut() {
Some(c) => c,
None => break, // Connection already closed
};
// Check if peer already sent close_notify
if self.check_peer_closed(conn, vm)? {
break;
}
drop(conn_guard);
// Check timeout
let remaining_timeout = if let Some(dl) = deadline {
let now = std::time::Instant::now();
if now >= dl {
// Timeout reached - raise TimeoutError
return Err(vm.new_exception_msg(
vm.ctx.exceptions.timeout_error.to_owned(),
"The read operation timed out".to_owned(),
));
}
Some(dl - now)
} else {
None // Blocking mode: no timeout
};
// Wait for socket to be readable
let timed_out = self.sock_wait_for_io_with_timeout(
SelectKind::Read,
remaining_timeout,
vm,
)?;
if timed_out {
// Timeout waiting for peer's close_notify
// Raise TimeoutError
return Err(vm.new_exception_msg(
vm.ctx.exceptions.timeout_error.to_owned(),
"The read operation timed out".to_owned(),
));
}
// Try to read data from socket
let mut conn_guard = self.connection.lock();
let conn = match conn_guard.as_mut() {
Some(c) => c,
None => break,
};
// Read and process any incoming TLS data
match self.try_read_close_notify(conn, vm) {
Ok(closed) => {
if closed {
break;
}
// Check again after processing
if self.check_peer_closed(conn, vm)? {
break;
}
}
Err(_) => {
// Socket error - peer likely closed connection
break;
}
}
}
// Shutdown complete
*self.shutdown_state.lock() = ShutdownState::Completed;
*self.connection.lock() = None;
return Ok(self.sock.clone());
}
}
}
// Step 3: Check again if peer has sent close_notify (non-blocking/BIO mode only)
// Reached only when neither socket branch above returned: BIO mode, or socket
// mode where `timeout_mode` is None.
if !peer_closed {
peer_closed = self.check_peer_closed(conn, vm)?;
}
}
drop(conn_guard); // Release lock before returning
if !peer_closed {
// Still waiting for peer's close-notify
// Raise SSLWantReadError to signal app needs to transfer data
// This is correct for non-blocking sockets and BIO mode
return Err(create_ssl_want_read_error(vm).upcast());
}
// Both close-notify exchanged, shutdown complete
*self.shutdown_state.lock() = ShutdownState::Completed;
if is_bio {
return Ok(vm.ctx.none());
}
Ok(self.sock.clone())
}
// Helper: drain every TLS record the connection wants to send (close_notify
// included) and hand it to `send_tls_output`.
// Output left over from earlier calls is flushed first so records keep their order;
// a failure of that flush aborts the call.
fn write_pending_tls(&self, conn: &mut TlsConnection, vm: &VirtualMachine) -> PyResult<()> {
    self.flush_pending_tls_output(vm, None)?;

    // One scratch buffer is reused; each pass writes from its start and only the
    // first `produced` bytes are forwarded.
    let mut scratch = vec![0u8; SSL3_RT_MAX_PLAIN_LENGTH];
    while conn.wants_write() {
        let produced = match conn.write_tls(&mut scratch.as_mut_slice()) {
            Ok(count) => count,
            Err(e) => return Err(vm.new_os_error(format!("TLS write failed: {e}"))),
        };
        if produced == 0 {
            break;
        }
        // Unsent bytes are kept in the pending buffer by send_tls_output if needed.
        self.send_tls_output(scratch[..produced].to_vec(), vm)?;
    }
    Ok(())
}
// Helper: Try to read incoming data from socket/BIO
// Returns true if peer closed connection (with or without close_notify)
fn try_read_close_notify(
&self,
conn: &mut TlsConnection,
vm: &VirtualMachine,
) -> PyResult<bool> {
// In socket mode, peek first to avoid consuming post-TLS cleartext
// data. During STARTTLS, after close_notify exchange, the socket
// transitions to cleartext. Without peeking, sock_recv may consume
// cleartext data meant for the application after unwrap().
if self.incoming_bio.is_none() {
return self.try_read_close_notify_socket(conn, vm);
}
// BIO mode: read from incoming BIO
match self.sock_recv(SSL3_RT_MAX_PLAIN_LENGTH, vm) {
Ok(bytes_obj) => {
let bytes = ArgBytesLike::try_from_object(vm, bytes_obj)?;
let data = bytes.borrow_buf();
if data.is_empty() {
if let Some(ref bio) = self.incoming_bio {
// BIO mode: check if EOF was signaled via write_eof()
let bio_obj: PyObjectRef = bio.clone().into();
let eof_attr = bio_obj.get_attr("eof", vm)?;
let is_eof = eof_attr.try_to_bool(vm)?;
if !is_eof {
return Ok(false);
}
}
return Ok(true);
}
let data_slice: &[u8] = data.as_ref();
let mut cursor = std::io::Cursor::new(data_slice);
let _ = conn.read_tls(&mut cursor);
let _ = conn.process_new_packets();
Ok(false)
}
Err(e) => {
if is_blocking_io_error(&e, vm) {
return Ok(false);
}
Ok(true)
}
}
}
/// Socket-mode variant of the close_notify reader that consumes at most one TLS
/// record per call.
///
/// The record header is inspected with a peek before anything is consumed, so
/// cleartext that follows the TLS stream (STARTTLS-style transitions) stays in the
/// socket. This plays the role of OpenSSL's `SSL_set_read_ahead(ssl, 0)`, which has
/// no counterpart in rustls.
///
/// Returns `true` when the peer is considered closed (EOF, non-TLS data, or a
/// non-blocking-related socket error) and `false` otherwise.
///
/// NOTE(review): each call assumes the next unread byte starts a record header;
/// if a receive returned only part of a record, the following peek would land
/// mid-record — confirm whether callers can hit this.
fn try_read_close_notify_socket(
    &self,
    conn: &mut TlsConnection,
    vm: &VirtualMachine,
) -> PyResult<bool> {
    // Look at up to 5 bytes: the size of a TLS record header.
    let header_obj = match self.sock_peek(5, vm) {
        Ok(obj) => obj,
        Err(exc) => return Ok(!is_blocking_io_error(&exc, vm)),
    };
    let header_bytes = ArgBytesLike::try_from_object(vm, header_obj)?;
    let header = header_bytes.borrow_buf();

    // Nothing to peek at means EOF.
    let Some(&content_type) = header.first() else {
        return Ok(true);
    };
    // Valid record types: ChangeCipherSpec(20), Alert(21), Handshake(22),
    // ApplicationData(23). Anything else is cleartext after TLS shutdown and must
    // be left in the socket.
    if !matches!(content_type, 20..=23) {
        return Ok(true);
    }

    // With a full header, request exactly one record (header + payload length);
    // with a partial header, just take the bytes that are there for now.
    let recv_size = if header.len() < 5 {
        header.len()
    } else {
        let payload_len = u16::from_be_bytes([header[3], header[4]]) as usize;
        5 + payload_len
    };
    drop(header);
    drop(header_bytes);

    // Consume that many bytes and feed them to the connection.
    let received = match self.sock_recv(recv_size, vm) {
        Ok(obj) => obj,
        Err(exc) => return Ok(!is_blocking_io_error(&exc, vm)),
    };
    let bytes = ArgBytesLike::try_from_object(vm, received)?;
    let data = bytes.borrow_buf();
    if data.is_empty() {
        return Ok(true);
    }
    let chunk: &[u8] = data.as_ref();
    let mut cursor = std::io::Cursor::new(chunk);
    let _ = conn.read_tls(&mut cursor);
    let _ = conn.process_new_packets();
    Ok(false)
}
// Helper: process any packets already received and report whether the peer has
// closed its side. Processing failures are raised as OSError.
fn check_peer_closed(
    &self,
    conn: &mut TlsConnection,
    vm: &VirtualMachine,
) -> PyResult<bool> {
    match conn.process_new_packets() {
        Ok(state) => Ok(state.peer_has_closed()),
        Err(e) => Err(vm.new_os_error(format!("Failed to process packets: {e}"))),
    }
}
#[pymethod]
fn shared_ciphers(&self, vm: &VirtualMachine) -> Option<PyListRef> {
    // Only a server-side socket whose handshake has finished has anything to
    // report; the handshake flag is consulted only for server sockets.
    if !(self.server_side && *self.handshake_done.lock()) {
        return None;
    }

    // No connection or no negotiated suite also yields None.
    let guard = self.connection.lock();
    let suite = guard.as_ref()?.negotiated_cipher_suite()?;
    let info = extract_cipher_info(&suite);

    // The result is a one-element list holding (name, version, bits) for the
    // negotiated suite.
    let entry = vm.ctx.new_tuple(vec![
        vm.ctx.new_str(info.name).into(),
        vm.ctx.new_str(info.protocol).into(),
        vm.ctx.new_int(info.bits).into(),
    ]);
    Some(vm.ctx.new_list(vec![entry.into()]))
}
#[pymethod]
fn verify_client_post_handshake(&self, vm: &VirtualMachine) -> PyResult<()> {
    // TLS 1.3 post-handshake authentication: only meaningful for a server-side
    // socket whose handshake has completed. Every path through this method
    // ends in an error, because the backend cannot perform the operation.
    if !self.server_side {
        return Err(vm.new_value_error(
            "Cannot perform post-handshake authentication on client-side socket",
        ));
    }
    if !*self.handshake_done.lock() {
        return Err(vm.new_value_error(
            "Handshake must be completed before post-handshake authentication",
        ));
    }

    let guard = self.connection.lock();
    let Some(conn) = guard.as_ref() else {
        return Err(vm.new_value_error("No SSL connection established"));
    };
    let TlsConnection::Server(server) = conn else {
        return Err(vm.new_value_error(
            "Post-handshake authentication requires server socket",
        ));
    };

    // Anything other than TLS 1.3 is reported as WRONG_SSL_VERSION.
    if server.protocol_version() != Some(rustls::ProtocolVersion::TLSv1_3) {
        // The exception class is looked up on the Python-level `ssl` module
        // (ssl.py re-exports _ssl.SSLError as ssl.SSLError).
        let ssl_error_class = vm.import("ssl", 0)?.get_attr("SSLError", vm)?;
        let msg = "[SSL: WRONG_SSL_VERSION] wrong ssl version";
        let args = vm.ctx.new_tuple(vec![vm.ctx.new_str(msg).into()]);
        let exc = ssl_error_class.call((args,), vm)?;
        return Err(exc
            .downcast()
            .map_err(|_| vm.new_type_error("Failed to create SSLError"))?);
    }

    // rustls has no API for requesting a client certificate after the initial
    // handshake. SSLError (rather than NotImplementedError) is raised for
    // compatibility.
    let unsupported = vm.new_os_subtype_error(
        PySSLError::class(&vm.ctx).to_owned(),
        None,
        "Post-handshake authentication is not supported by the rustls backend. \
         The rustls TLS library does not provide an API to request client certificates \
         after the initial handshake. Consider requesting the client certificate \
         during the initial handshake by setting the appropriate verify_mode before \
         calling do_handshake().",
    );
    Err(unsupported.upcast())
}
#[pymethod]
fn get_channel_binding(
&self,
cb_type: OptionalArg<PyStrRef>,
vm: &VirtualMachine,
) -> PyResult<Option<PyBytesRef>> {
let cb_type_str = cb_type.as_ref().map_or("tls-unique", |s| s.as_str());
// rustls doesn't support channel binding (tls-unique, tls-server-end-point, etc.)
// This is because:
// 1. tls-unique requires access to TLS Finished messages, which rustls doesn't expose
// 2. tls-server-end-point requires the server certificate, which we don't track here
// 3. TLS 1.3 deprecated tls-unique anyway
//
// For compatibility, we'll return None (no channel binding available)
// rather than raising an error
if cb_type_str != "tls-unique" {
return Err(vm.new_value_error(format!(
"Unsupported channel binding type '{cb_type_str}'",
)));
}
// Return None to indicate channel binding is not available
// This matches the behavior when the handshake hasn't completed yet
Ok(None)
}
}
impl Constructor for PySSLSocket {
    type Args = ();

    // Direct construction from Python is always rejected with TypeError;
    // the message points callers at SSLContext.wrap_socket().
    fn slot_new(_cls: PyTypeRef, _args: FuncArgs, vm: &VirtualMachine) -> PyResult {
        let reason = "Cannot directly instantiate SSLSocket, use SSLContext.wrap_socket()";
        Err(vm.new_type_error(reason))
    }

    // Never expected to run: `slot_new` above is overridden and always errors.
    fn py_new(_cls: &Py<PyType>, _args: Self::Args, _vm: &VirtualMachine) -> PyResult<Self> {
        unimplemented!("use slot_new")
    }
}
// Clean up SSL socket resources on drop
impl Drop for PySSLSocket {
    fn drop(&mut self) {
        // Release the connection state only.
        //
        // pending_tls_output is deliberately left alone: it may still hold data
        // that has not been flushed to the socket. SSLSocket._real_close() in
        // Python doesn't call shutdown(), so clearing it here would lose that
        // data when the socket is closed. The remaining fields are freed
        // automatically with the struct.
        drop(self.connection.lock().take());
    }
}
// MemoryBIO - provides in-memory buffer for SSL/TLS I/O
// Registered under the Python name "MemoryBIO" in module "ssl" (see the
// #[pyclass] attribute below).
#[pyattr]
#[pyclass(name = "MemoryBIO", module = "ssl")]
#[derive(Debug, PyPayload)]
struct PyMemoryBIO {
// Internal buffer: bytes written but not yet read. `write` appends at the
// back and `read` drains from the front.
buffer: PyMutex<Vec<u8>>,
// EOF flag: set by `write_eof`. The `eof` getter reports true only when this
// flag is set AND `buffer` is empty.
eof: PyRwLock<bool>,
}
#[pyclass(with(Constructor), flags(BASETYPE))]
impl PyMemoryBIO {
    // Read up to `len` bytes from the front of the buffer.
    //
    // A missing or negative `len` returns everything currently buffered, which
    // is the documented contract of `ssl.MemoryBIO.read(n=-1)`. An empty buffer
    // yields b"" whether or not EOF has been signalled.
    #[pymethod]
    fn read(&self, len: OptionalArg<i32>, vm: &VirtualMachine) -> PyResult<PyBytesRef> {
        let mut buffer = self.buffer.lock();
        let available = buffer.len();
        let wanted = match len {
            // `try_from` fails only for negative values, which mean "read all".
            OptionalArg::Present(n) => {
                usize::try_from(n).map_or(available, |n| n.min(available))
            }
            OptionalArg::Missing => available,
        };
        let data: Vec<u8> = buffer.drain(..wanted).collect();
        Ok(vm.ctx.new_bytes(data))
    }

    // Append the contents of a bytes-like object to the buffer and return the
    // number of bytes written.
    //
    // Raises BufferError for an object that exposes a falsy `c_contiguous`
    // attribute; conversion errors from the bytes-like lookup propagate as-is.
    #[pymethod]
    fn write(&self, buf: PyObjectRef, vm: &VirtualMachine) -> PyResult<usize> {
        // Objects exposing `c_contiguous` (e.g. memoryview) must be contiguous.
        if let Ok(contiguous_attr) = buf.get_attr("c_contiguous", vm) {
            let is_contiguous: bool = contiguous_attr.try_to_bool(vm)?;
            if !is_contiguous {
                return Err(vm.new_exception_msg(
                    vm.ctx.exceptions.buffer_error.to_owned(),
                    "non-contiguous buffer is not supported".to_owned(),
                ));
            }
        }

        let bytes_like = ArgBytesLike::try_from_object(vm, buf)?;
        let data = bytes_like.borrow_buf();
        let len = data.len();
        self.buffer.lock().extend_from_slice(&data);
        Ok(len)
    }

    // Mark the BIO as finished; `eof` turns true once the buffer is drained.
    #[pymethod]
    fn write_eof(&self, _vm: &VirtualMachine) -> PyResult<()> {
        *self.eof.write() = true;
        Ok(())
    }

    // Number of bytes currently buffered.
    #[pygetset]
    fn pending(&self) -> i32 {
        // Saturate rather than wrap to a negative count if the buffer ever
        // exceeds i32::MAX bytes.
        i32::try_from(self.buffer.lock().len()).unwrap_or(i32::MAX)
    }

    // True only when the buffer is empty AND write_eof has been called.
    #[pygetset]
    fn eof(&self) -> bool {
        // The buffer guard is released before the flag is read.
        let drained = self.buffer.lock().is_empty();
        drained && *self.eof.read()
    }
}
impl Representable for PyMemoryBIO {
    // Fixed textual representation; no instance state is included.
    #[inline]
    fn repr_str(_zelf: &Py<Self>, _vm: &VirtualMachine) -> PyResult<String> {
        Ok(String::from("<MemoryBIO>"))
    }
}
impl Constructor for PyMemoryBIO {
    type Args = ();

    // A new BIO starts with an empty buffer and the EOF flag cleared.
    fn py_new(_cls: &Py<PyType>, _args: Self::Args, _vm: &VirtualMachine) -> PyResult<Self> {
        let buffer = PyMutex::new(Vec::new());
        let eof = PyRwLock::new(false);
        Ok(Self { buffer, eof })
    }
}
// SSLSession - represents a cached SSL session
// NOTE: This is an EMULATION - actual session data is managed by Rustls internally
#[pyattr]
#[pyclass(name = "SSLSession", module = "ssl")]
#[derive(Debug, PyPayload)]
struct PySSLSession {
// Session data - serialized rustls session (EMULATED - kept empty).
// The `id` getter hashes these bytes and `has_ticket` reports whether they
// are non-empty.
session_data: Vec<u8>,
// Session ID - synthetic ID generated from metadata (NOT actual TLS session ID).
// No getter in the accompanying impl reads it, hence the dead_code allowance.
#[allow(dead_code)]
session_id: Vec<u8>,
// Session metadata: creation instant, exposed by the `time` getter as whole
// seconds since the Unix epoch.
creation_time: std::time::SystemTime,
// Lifetime in seconds (default 7200 = 2 hours); returned by both the
// `timeout` and `ticket_lifetime_hint` getters.
lifetime: u64,
}
#[pyclass(flags(BASETYPE))]
impl PySSLSession {
    // Creation time as a Unix timestamp in whole seconds; a creation time
    // before the epoch is reported as 0.
    #[pygetset]
    fn time(&self) -> i64 {
        match self.creation_time.duration_since(std::time::UNIX_EPOCH) {
            Ok(elapsed) => elapsed.as_secs() as i64,
            Err(_) => 0,
        }
    }

    // Session timeout/lifetime in seconds.
    #[pygetset]
    fn timeout(&self) -> i64 {
        self.lifetime as i64
    }

    // Ticket lifetime hint; this emulation reports the same value as `timeout`.
    #[pygetset]
    fn ticket_lifetime_hint(&self) -> i64 {
        self.lifetime as i64
    }

    // Session ID: the 8 big-endian bytes of a hash over the session data.
    // NOTE(review): the struct comment says session_data is kept empty, in which
    // case every session would report the same id - confirm this is intended.
    #[pygetset]
    fn id(&self, vm: &VirtualMachine) -> PyBytesRef {
        use core::hash::{Hash, Hasher};
        use std::collections::hash_map::DefaultHasher;

        let digest = {
            let mut state = DefaultHasher::new();
            self.session_data.hash(&mut state);
            state.finish()
        };
        vm.ctx.new_bytes(Vec::from(digest.to_be_bytes()))
    }

    // For rustls, having session data is treated as having a ticket.
    #[pygetset]
    fn has_ticket(&self) -> bool {
        let no_data = self.session_data.is_empty();
        !no_data
    }
}
impl Representable for PySSLSession {
    // Fixed textual representation; no session details are included.
    #[inline]
    fn repr_str(_zelf: &Py<Self>, _vm: &VirtualMachine) -> PyResult<String> {
        Ok(String::from("<SSLSession>"))
    }
}
// Helper functions
// OID module already imported at top of _ssl module
// Arguments accepted by `txt2obj`.
#[derive(FromArgs)]
struct Txt2ObjArgs {
// Text to resolve: an OID string, or a short/long name when `name` is true.
txt: PyStrRef,
// Optional flag; `txt2obj` treats a missing value as false.
#[pyarg(named, optional)]
name: OptionalArg<bool>,
}
#[pyfunction]
fn txt2obj(args: Txt2ObjArgs, vm: &VirtualMachine) -> PyResult<PyObjectRef> {
    let txt = args.txt.as_str();
    let allow_names = args.name.unwrap_or(false);

    // Text starting with an ASCII digit is always treated as an OID string.
    // Anything else is looked up by short/long name, but only when name=True;
    // with name=False (the default) non-OID text is never resolved.
    let looks_like_oid = txt.starts_with(|c: char| c.is_ascii_digit());
    let found = match (looks_like_oid, allow_names) {
        (true, _) => oid::find_by_oid_string(txt),
        (false, true) => oid::find_by_name(txt),
        (false, false) => None,
    };
    let Some(entry) = found else {
        return Err(vm.new_value_error(format!("unknown object '{txt}'")));
    };

    // Result tuple: (nid, shortname, longname, oid)
    let fields = (
        vm.ctx.new_int(entry.nid),
        vm.ctx.new_str(entry.short_name),
        vm.ctx.new_str(entry.long_name),
        vm.ctx.new_str(entry.oid_string()),
    );
    Ok(vm.new_tuple(fields).into())
}
#[pyfunction]
fn nid2obj(nid: i32, vm: &VirtualMachine) -> PyResult<PyObjectRef> {
    // An unknown NID is a ValueError.
    let entry = match oid::find_by_nid(nid) {
        Some(entry) => entry,
        None => return Err(vm.new_value_error(format!("unknown NID {nid}"))),
    };

    // Result tuple: (nid, shortname, longname, oid)
    let fields = (
        vm.ctx.new_int(entry.nid),
        vm.ctx.new_str(entry.short_name),
        vm.ctx.new_str(entry.long_name),
        vm.ctx.new_str(entry.oid_string()),
    );
    Ok(vm.new_tuple(fields).into())
}
#[pyfunction]
fn get_default_verify_paths(vm: &VirtualMachine) -> PyResult<PyObjectRef> {
    // Lib/ssl.py expects a 4-tuple:
    //   (openssl_cafile_env, openssl_cafile, openssl_capath_env, openssl_capath)
    // i.e. env-var name for the CA file, default CA file, env-var name for the
    // CA directory, default CA directory.

    // macOS and Linux share the same defaults. macOS primarily uses the Keychain
    // and these are fallback paths; on Linux they match openssl-probe's first
    // candidate (openssl-probe itself checks more locations at runtime).
    #[cfg(any(target_os = "macos", target_os = "linux"))]
    let (default_cafile, default_capath) =
        (Some("/etc/ssl/cert.pem"), Some("/etc/ssl/certs"));

    // Windows uses the certificate store, not file paths. Empty strings are
    // returned so that None is never passed to os.path.isfile().
    #[cfg(windows)]
    let (default_cafile, default_capath) = (Some(""), Some(""));

    // Any other platform has no known defaults.
    #[cfg(not(any(target_os = "macos", target_os = "linux", windows)))]
    let (default_cafile, default_capath): (Option<&str>, Option<&str>) = (None, None);

    // A missing default becomes Python None.
    let path_or_none = |path: Option<&str>| -> PyObjectRef {
        match path {
            Some(p) => vm.ctx.new_str(p).into(),
            None => vm.ctx.none(),
        }
    };

    let parts = vec![
        vm.ctx.new_str("SSL_CERT_FILE").into(), // openssl_cafile_env
        path_or_none(default_cafile),           // openssl_cafile
        vm.ctx.new_str("SSL_CERT_DIR").into(),  // openssl_capath_env
        path_or_none(default_capath),           // openssl_capath
    ];
    Ok(vm.ctx.new_tuple(parts).into())
}
#[pyfunction]
fn RAND_status() -> i32 {
    // Randomness from aws-lc-rs is always considered good, so the status is
    // unconditionally 1.
    const SEEDED: i32 = 1;
    SEEDED
}
// Accepts caller-supplied entropy and discards it.
#[pyfunction]
fn RAND_add(_string: PyObjectRef, _entropy: f64) {
// No-op: aws-lc-rs handles its own entropy
// `_string` is a generic object so that any type (str, bytes, bytearray) is accepted
}
// Return `n` random bytes from the aws-lc-rs system RNG.
//
// Raises ValueError when `n` is negative, OverflowError when `n` cannot be
// represented as a buffer length on this target, and OSError when the
// generator fails.
#[pyfunction]
fn RAND_bytes(n: i64, vm: &VirtualMachine) -> PyResult<PyBytesRef> {
    use aws_lc_rs::rand::{SecureRandom, SystemRandom};

    if n < 0 {
        return Err(vm.new_value_error("num must be positive"));
    }
    // Checked conversion: a plain `as usize` cast silently truncates on targets
    // whose pointer width is below 64 bits, returning fewer bytes than requested.
    let len = usize::try_from(n).map_err(|_| vm.new_overflow_error("num is too large"))?;

    let mut buf = vec![0u8; len];
    SystemRandom::new()
        .fill(&mut buf)
        .map_err(|_| vm.new_os_error("Failed to generate random bytes"))?;
    Ok(vm.ctx.new_bytes(buf))
}
#[pyfunction]
fn RAND_pseudo_bytes(n: i64, vm: &VirtualMachine) -> PyResult<(PyBytesRef, bool)> {
    // With rustls/aws-lc-rs every random byte is cryptographically strong, so
    // the "is strong" flag paired with the bytes is always true.
    RAND_bytes(n, vm).map(|bytes| (bytes, true))
}
/// Test helper to decode a certificate from a file path
///
/// This is a simplified wrapper around cert_der_to_dict_helper that handles
/// file reading and PEM/DER auto-detection. Used by test suite.
#[pyfunction]
fn _test_decode_cert(path: PyStrRef, vm: &VirtualMachine) -> PyResult<PyObjectRef> {
    // Marker whose presence anywhere in the file selects the PEM code path.
    const PEM_MARKER: &[u8] = b"-----BEGIN CERTIFICATE-----";

    let cert_data = match std::fs::read(path.as_str()) {
        Ok(bytes) => bytes,
        Err(e) => {
            return Err(vm.new_os_error(format!(
                "Failed to read certificate file {}: {}",
                path.as_str(),
                e
            )));
        }
    };

    let is_pem = cert_data
        .windows(PEM_MARKER.len())
        .any(|window| window == PEM_MARKER);

    let cert_der = if is_pem {
        // Take the first certificate that parses; unparsable PEM items are skipped.
        let mut cursor = std::io::Cursor::new(&cert_data);
        let first = rustls_pemfile::certs(&mut cursor).find_map(Result::ok);
        match first {
            Some(cert) => cert.to_vec(),
            None => {
                return Err(vm.new_value_error("No valid certificate found in PEM file"));
            }
        }
    } else {
        // Without the marker the bytes are assumed to be DER already.
        cert_data
    };

    cert::cert_der_to_dict_helper(vm, &cert_der)
}
#[pyfunction]
fn DER_cert_to_PEM_cert(der_cert: ArgBytesLike, vm: &VirtualMachine) -> PyResult<PyStrRef> {
    let der = der_cert.borrow_buf();
    // pem-rfc7468 produces RFC 7468 compliant PEM with a CERTIFICATE label and
    // LF line endings; an encoding failure is a ValueError.
    match encode_string("CERTIFICATE", LineEnding::LF, &der) {
        Ok(pem_str) => Ok(vm.ctx.new_str(pem_str)),
        Err(e) => Err(vm.new_value_error(format!("PEM encoding failed: {e}"))),
    }
}
#[pyfunction]
fn PEM_cert_to_DER_cert(pem_cert: PyStrRef, vm: &VirtualMachine) -> PyResult<PyBytesRef> {
    let mut reader = std::io::Cursor::new(pem_cert.as_str().as_bytes());
    // Only the first PEM item is considered: it must exist and parse cleanly,
    // otherwise the call fails with ValueError.
    let first = rustls_pemfile::certs(&mut reader).next();
    match first {
        Some(Ok(cert)) => Ok(vm.ctx.new_bytes(cert.to_vec())),
        _ => Err(vm.new_value_error("Failed to parse PEM certificate")),
    }
}
// Windows-specific certificate store enumeration functions
#[cfg(windows)]
#[pyfunction]
fn enum_certificates(store_name: PyStrRef, vm: &VirtualMachine) -> PyResult<Vec<PyObjectRef>> {
    use schannel::{RawPointer, cert_context::ValidUses, cert_store::CertStore};
    use windows_sys::Win32::Security::Cryptography;

    // Try both the Current User and the Local Machine variants of the store,
    // keeping whichever can be opened.
    let open_fns = [CertStore::open_current_user, CertStore::open_local_machine];
    let mut stores = Vec::new();
    for open in open_fns.iter() {
        if let Ok(store) = open(store_name.as_str()) {
            stores.push(store);
        }
    }

    // Failing to open both is an OSError.
    if stores.is_empty() {
        return Err(vm.new_os_error(format!(
            "failed to open certificate store {:?}",
            store_name.as_str()
        )));
    }

    // One (der_bytes, encoding_type, usage) tuple per certificate.
    let mut entries: Vec<PyObjectRef> = Vec::new();
    for store in &stores {
        for c in store.certs() {
            let cert = vm.ctx.new_bytes(c.to_der().to_owned());

            // SAFETY: presumably `as_ptr` yields the live CERT_CONTEXT backing `c`,
            // which stays alive for this iteration - TODO confirm against schannel.
            let raw_encoding = unsafe {
                let ptr = c.as_ptr() as *const Cryptography::CERT_CONTEXT;
                (*ptr).dwCertEncodingType
            };
            // Known encodings are reported by name, anything else as the raw number.
            let enc_type = match raw_encoding {
                Cryptography::X509_ASN_ENCODING => vm.new_pyobj("x509_asn"),
                Cryptography::PKCS_7_ASN_ENCODING => vm.new_pyobj("pkcs_7_asn"),
                other => vm.new_pyobj(other),
            };

            // Usage is a frozenset of OID strings when the certificate lists them;
            // "all uses" and every failure path are reported as True.
            let usage: PyObjectRef = match c.valid_uses() {
                Ok(ValidUses::Oids(oids)) => {
                    match crate::builtins::PyFrozenSet::from_iter(
                        vm,
                        oids.into_iter().map(|oid| vm.ctx.new_str(oid).into()),
                    ) {
                        Ok(set) => set.into_ref(&vm.ctx).into(),
                        Err(_) => vm.ctx.new_bool(true).into(),
                    }
                }
                Ok(ValidUses::All) | Err(_) => vm.ctx.new_bool(true).into(),
            };

            entries.push(vm.new_tuple((cert, enc_type, usage)).into());
        }
    }
    Ok(entries)
}
// Enumerate the CRLs held in a Windows system certificate store.
//
// Returns one (crl_bytes, encoding_type) tuple per CRL. The encoding type is
// "x509_asn" or "pkcs_7_asn" for the known encodings (the same mapping
// `enum_certificates` uses) and the raw numeric value otherwise. Raises OSError
// when the store cannot be opened.
#[cfg(windows)]
#[pyfunction]
fn enum_crls(store_name: PyStrRef, vm: &VirtualMachine) -> PyResult<Vec<PyObjectRef>> {
    use windows_sys::Win32::Security::Cryptography::{
        CRL_CONTEXT, CertCloseStore, CertEnumCRLsInStore, CertOpenSystemStoreW,
        PKCS_7_ASN_ENCODING, X509_ASN_ENCODING,
    };

    // NUL-terminated UTF-16 copy of the store name for the wide-character API.
    let store_name_wide: Vec<u16> = store_name
        .as_str()
        .encode_utf16()
        .chain(core::iter::once(0))
        .collect();

    // SAFETY: `store_name_wide` is NUL-terminated and outlives the call.
    let store = unsafe { CertOpenSystemStoreW(0, store_name_wide.as_ptr()) };
    if store.is_null() {
        return Err(vm.new_os_error(format!(
            "failed to open certificate store {:?}",
            store_name.as_str()
        )));
    }

    let mut result = Vec::new();
    let mut crl_context: *const CRL_CONTEXT = core::ptr::null();
    loop {
        // Passing the previous context back in advances the enumeration; a null
        // result ends it.
        // SAFETY: `store` is a non-null handle opened above, and `crl_context` is
        // either null or the value returned by the previous call.
        crl_context = unsafe { CertEnumCRLsInStore(store, crl_context) };
        if crl_context.is_null() {
            break;
        }

        // SAFETY: `crl_context` was just checked to be non-null.
        let crl = unsafe { &*crl_context };
        // SAFETY: assumes `pbCrlEncoded` points at `cbCrlEncoded` readable bytes
        // for as long as this context is current; the bytes are copied below
        // before the next enumeration step.
        let crl_bytes =
            unsafe { core::slice::from_raw_parts(crl.pbCrlEncoded, crl.cbCrlEncoded as usize) };

        let enc_type = match crl.dwCertEncodingType {
            X509_ASN_ENCODING => vm.new_pyobj("x509_asn"),
            PKCS_7_ASN_ENCODING => vm.new_pyobj("pkcs_7_asn"),
            other => vm.new_pyobj(other),
        };
        result.push(
            vm.new_tuple((vm.ctx.new_bytes(crl_bytes.to_vec()), enc_type))
                .into(),
        );
    }

    // SAFETY: `store` is the handle opened above and is not used afterwards.
    unsafe { CertCloseStore(store, 0) };
    Ok(result)
}
// Certificate type for SSL module (pure Rust implementation)
// Registered under the Python name "Certificate" in module "_ssl" (see the
// #[pyclass] attribute below).
#[pyattr]
#[pyclass(module = "_ssl", name = "Certificate")]
#[derive(Debug, PyPayload)]
pub struct PySSLCertificate {
// Store the raw DER bytes. Equality, hashing and `public_bytes` operate on
// these bytes directly; they are only parsed on demand by `parse`.
der_bytes: Vec<u8>,
}
impl PySSLCertificate {
    // Parse the stored DER bytes on demand. Nothing is cached, so every call
    // re-parses; failures are returned as a human-readable message.
    fn parse(&self) -> Result<x509_parser::certificate::X509Certificate<'_>, String> {
        x509_parser::parse_x509_certificate(&self.der_bytes)
            .map(|(_, cert)| cert)
            .map_err(|e| format!("Failed to parse certificate: {e}"))
    }
}
#[pyclass(with(Comparable, Hashable, Representable))]
impl PySSLCertificate {
    // Serialize the certificate: DER yields the raw bytes, PEM (the default)
    // yields an RFC 7468 encoded string. Any other format is a ValueError.
    #[pymethod]
    fn public_bytes(
        &self,
        format: OptionalArg<i32>,
        vm: &VirtualMachine,
    ) -> PyResult<PyObjectRef> {
        let format = format.unwrap_or(ENCODING_PEM);
        if format == ENCODING_DER {
            return Ok(vm.ctx.new_bytes(self.der_bytes.clone()).into());
        }
        if format == ENCODING_PEM {
            return match encode_string("CERTIFICATE", LineEnding::LF, &self.der_bytes) {
                Ok(pem_str) => Ok(vm.ctx.new_str(pem_str).into()),
                Err(e) => Err(vm.new_value_error(format!("PEM encoding failed: {e}"))),
            };
        }
        Err(vm.new_value_error("Unsupported format"))
    }

    // Parse the certificate and convert it with `cert::cert_to_dict`; a parse
    // failure is a ValueError carrying the parser's message.
    #[pymethod]
    fn get_info(&self, vm: &VirtualMachine) -> PyResult {
        match self.parse() {
            Ok(cert) => cert::cert_to_dict(vm, &cert),
            Err(e) => Err(vm.new_value_error(e)),
        }
    }
}
// Implement Comparable trait for PySSLCertificate
impl Comparable for PySSLCertificate {
    // Equality-only comparison: two certificates are equal when their DER bytes
    // match. Comparing against any other type answers NotImplemented.
    fn cmp(
        zelf: &Py<Self>,
        other: &PyObject,
        op: PyComparisonOp,
        _vm: &VirtualMachine,
    ) -> PyResult<PyComparisonValue> {
        op.eq_only(|| {
            Ok(match other.downcast_ref::<Self>() {
                Some(other_cert) => (zelf.der_bytes == other_cert.der_bytes).into(),
                None => PyComparisonValue::NotImplemented,
            })
        })
    }
}
// Implement Hashable trait for PySSLCertificate
impl Hashable for PySSLCertificate {
    // The hash is derived from the DER bytes, the same data equality compares.
    fn hash(zelf: &Py<Self>, _vm: &VirtualMachine) -> PyResult<PyHash> {
        use core::hash::{Hash, Hasher};
        use std::collections::hash_map::DefaultHasher;

        let digest = {
            let mut state = DefaultHasher::new();
            zelf.der_bytes.hash(&mut state);
            state.finish()
        };
        Ok(digest as PyHash)
    }
}
// Implement Representable trait for PySSLCertificate
impl Representable for PySSLCertificate {
    // Shows the subject's first common name when the certificate parses,
    // "Unknown" when there is no usable CN, and a fixed marker when the DER
    // bytes do not parse at all.
    #[inline]
    fn repr_str(zelf: &Py<Self>, _vm: &VirtualMachine) -> PyResult<String> {
        let Ok(cert) = zelf.parse() else {
            return Ok("<Certificate(invalid)>".to_owned());
        };
        let cn = cert
            .subject()
            .iter_common_name()
            .next()
            .and_then(|attr| attr.as_str().ok())
            .unwrap_or("Unknown");
        Ok(format!("<Certificate(subject=CN={cn})>"))
    }
}
}