ssl.{SSLSession,MemoryBIO} (#6209)

* SSLSession

* get_unverified_chain

* SSL MemoryBIO
This commit is contained in:
Jeong, YunWon
2025-10-23 18:37:40 +09:00
committed by GitHub
parent 2463bdff0e
commit 3ec905e08a
2 changed files with 398 additions and 10 deletions

4
Lib/ssl.py vendored
View File

@@ -98,7 +98,7 @@ from enum import Enum as _Enum, IntEnum as _IntEnum, IntFlag as _IntFlag
import _ssl # if we can't import it, let the error propagate
from _ssl import OPENSSL_VERSION_NUMBER, OPENSSL_VERSION_INFO, OPENSSL_VERSION
from _ssl import _SSLContext#, MemoryBIO, SSLSession
from _ssl import _SSLContext, SSLSession, MemoryBIO
from _ssl import (
SSLError, SSLZeroReturnError, SSLWantReadError, SSLWantWriteError,
SSLSyscallError, SSLEOFError, SSLCertVerificationError
@@ -114,7 +114,7 @@ except ImportError:
from _ssl import (
HAS_SNI, HAS_ECDH, HAS_NPN, HAS_ALPN, HAS_SSLv2, HAS_SSLv3, HAS_TLSv1,
HAS_TLSv1_1, HAS_TLSv1_2, HAS_TLSv1_3
HAS_TLSv1_1, HAS_TLSv1_2, HAS_TLSv1_3, HAS_PSK
)
from _ssl import _DEFAULT_CIPHERS, _OPENSSL_API_VERSION

View File

@@ -38,15 +38,16 @@ mod _ssl {
},
socket::{self, PySocket},
vm::{
PyObjectRef, PyPayload, PyRef, PyResult, VirtualMachine,
Py, PyObjectRef, PyPayload, PyRef, PyResult, VirtualMachine,
builtins::{PyBaseExceptionRef, PyStrRef, PyType, PyTypeRef, PyWeak},
class_or_notimplemented,
convert::{ToPyException, ToPyObject},
exceptions,
function::{
ArgBytesLike, ArgCallable, ArgMemoryBuffer, ArgStrOrBytesLike, Either, FsPath,
OptionalArg,
OptionalArg, PyComparisonValue,
},
types::Constructor,
types::{Comparable, Constructor, PyComparisonOp},
utils::ToCString,
},
};
@@ -162,6 +163,8 @@ mod _ssl {
const HAS_TLSv1_2: bool = true;
#[pyattr]
const HAS_TLSv1_3: bool = cfg!(ossl111);
#[pyattr]
const HAS_PSK: bool = true;
// the openssl version from the API headers
@@ -816,17 +819,46 @@ mod _ssl {
let stream = ssl::SslStream::new(ssl, SocketStream(args.sock.clone()))
.map_err(|e| convert_openssl_error(vm, e))?;
// TODO: use this
let _ = args.session;
Ok(PySslSocket {
let py_ssl_socket = PySslSocket {
ctx: zelf,
stream: PyRwLock::new(stream),
socket_type,
server_hostname: args.server_hostname,
owner: PyRwLock::new(args.owner.map(|o| o.downgrade(None, vm)).transpose()?),
})
};
// Set session if provided
if let Some(session) = args.session
&& !vm.is_none(&session)
{
py_ssl_socket.set_session(session, vm)?;
}
Ok(py_ssl_socket)
}
#[pymethod]
fn _wrap_bio(_zelf: PyRef<Self>, _args: WrapBioArgs, vm: &VirtualMachine) -> PyResult {
// TODO: Implement BIO-based SSL wrapping
// This requires refactoring PySslSocket to support both socket and BIO modes
Err(vm.new_not_implemented_error(
"_wrap_bio is not yet implemented in RustPython".to_owned(),
))
}
}
#[derive(FromArgs)]
#[allow(dead_code)] // Fields will be used when _wrap_bio is fully implemented
struct WrapBioArgs {
incoming: PyRef<PySslMemoryBio>,
outgoing: PyRef<PySslMemoryBio>,
server_side: bool,
#[pyarg(any, default)]
server_hostname: Option<PyStrRef>,
#[pyarg(named, default)]
owner: Option<PyObjectRef>,
#[pyarg(named, default)]
session: Option<PyObjectRef>,
}
#[derive(FromArgs)]
@@ -996,6 +1028,19 @@ mod _ssl {
.transpose()
}
#[pymethod]
fn get_unverified_chain(&self, vm: &VirtualMachine) -> Option<PyObjectRef> {
let stream = self.stream.read();
let chain = stream.ssl().peer_cert_chain()?;
let certs: Vec<PyObjectRef> = chain
.iter()
.filter_map(|cert| cert.to_der().ok().map(|der| vm.ctx.new_bytes(der).into()))
.collect();
Some(vm.ctx.new_list(certs).into())
}
#[pymethod]
fn version(&self) -> Option<&'static str> {
let v = self.stream.read().ssl().version_str();
@@ -1103,6 +1148,73 @@ mod _ssl {
}
}
#[pygetset]
fn session(&self, _vm: &VirtualMachine) -> PyResult<Option<PySslSession>> {
let stream = self.stream.read();
unsafe {
let session_ptr = sys::SSL_get_session(stream.ssl().as_ptr());
if session_ptr.is_null() {
Ok(None)
} else {
// Increment reference count since SSL_get_session returns a borrowed reference
#[cfg(ossl110)]
let _session = sys::SSL_SESSION_up_ref(session_ptr);
Ok(Some(PySslSession {
session: session_ptr,
ctx: self.ctx.clone(),
}))
}
}
}
#[pygetset(setter)]
fn set_session(&self, value: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> {
// Check if value is SSLSession type
let session = value
.downcast_ref::<PySslSession>()
.ok_or_else(|| vm.new_type_error("Value is not a SSLSession.".to_owned()))?;
// Check if session refers to the same SSLContext
if !std::ptr::eq(
self.ctx.ctx.read().as_ptr(),
session.ctx.ctx.read().as_ptr(),
) {
return Err(
vm.new_value_error("Session refers to a different SSLContext.".to_owned())
);
}
// Check if this is a client socket
if self.socket_type != SslServerOrClient::Client {
return Err(
vm.new_value_error("Cannot set session for server-side SSLSocket.".to_owned())
);
}
// Check if handshake is not finished
let stream = self.stream.read();
unsafe {
if sys::SSL_is_init_finished(stream.ssl().as_ptr()) != 0 {
return Err(
vm.new_value_error("Cannot set session after handshake.".to_owned())
);
}
if sys::SSL_set_session(stream.ssl().as_ptr(), session.session) == 0 {
return Err(convert_openssl_error(vm, ErrorStack::get()));
}
}
Ok(())
}
#[pygetset]
fn session_reused(&self) -> bool {
let stream = self.stream.read();
unsafe { sys::SSL_session_reused(stream.ssl().as_ptr()) != 0 }
}
#[pymethod]
fn read(
&self,
@@ -1164,6 +1276,282 @@ mod _ssl {
}
}
#[pyattr]
#[pyclass(module = "ssl", name = "SSLSession")]
#[derive(PyPayload)]
struct PySslSession {
session: *mut sys::SSL_SESSION,
ctx: PyRef<PySslContext>,
}
impl fmt::Debug for PySslSession {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.pad("SSLSession")
}
}
impl Drop for PySslSession {
fn drop(&mut self) {
if !self.session.is_null() {
unsafe {
sys::SSL_SESSION_free(self.session);
}
}
}
}
unsafe impl Send for PySslSession {}
unsafe impl Sync for PySslSession {}
impl Comparable for PySslSession {
fn cmp(
zelf: &Py<Self>,
other: &crate::vm::PyObject,
op: PyComparisonOp,
_vm: &VirtualMachine,
) -> PyResult<PyComparisonValue> {
let other = class_or_notimplemented!(Self, other);
if !matches!(op, PyComparisonOp::Eq | PyComparisonOp::Ne) {
return Ok(PyComparisonValue::NotImplemented);
}
let mut eq = unsafe {
let mut self_len: libc::c_uint = 0;
let mut other_len: libc::c_uint = 0;
let self_id = sys::SSL_SESSION_get_id(zelf.session, &mut self_len);
let other_id = sys::SSL_SESSION_get_id(other.session, &mut other_len);
if self_len != other_len {
false
} else {
let self_slice = std::slice::from_raw_parts(self_id, self_len as usize);
let other_slice = std::slice::from_raw_parts(other_id, other_len as usize);
self_slice == other_slice
}
};
if matches!(op, PyComparisonOp::Ne) {
eq = !eq;
}
Ok(PyComparisonValue::Implemented(eq))
}
}
#[pyattr]
#[pyclass(module = "ssl", name = "MemoryBIO")]
#[derive(PyPayload)]
struct PySslMemoryBio {
bio: *mut sys::BIO,
eof_written: AtomicCell<bool>,
}
impl fmt::Debug for PySslMemoryBio {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.pad("MemoryBIO")
}
}
impl Drop for PySslMemoryBio {
fn drop(&mut self) {
if !self.bio.is_null() {
unsafe {
sys::BIO_free_all(self.bio);
}
}
}
}
unsafe impl Send for PySslMemoryBio {}
unsafe impl Sync for PySslMemoryBio {}
// OpenSSL BIO helper functions
// These are typically macros in OpenSSL, implemented via BIO_ctrl
const BIO_CTRL_PENDING: libc::c_int = 10;
const BIO_CTRL_SET_EOF: libc::c_int = 2;
#[allow(non_snake_case)]
unsafe fn BIO_ctrl_pending(bio: *mut sys::BIO) -> usize {
unsafe { sys::BIO_ctrl(bio, BIO_CTRL_PENDING, 0, std::ptr::null_mut()) as usize }
}
#[allow(non_snake_case)]
unsafe fn BIO_set_mem_eof_return(bio: *mut sys::BIO, eof: libc::c_int) -> libc::c_int {
unsafe {
sys::BIO_ctrl(
bio,
BIO_CTRL_SET_EOF,
eof as libc::c_long,
std::ptr::null_mut(),
) as libc::c_int
}
}
#[allow(non_snake_case)]
unsafe fn BIO_clear_retry_flags(bio: *mut sys::BIO) {
unsafe {
sys::BIO_clear_flags(bio, sys::BIO_FLAGS_RWS | sys::BIO_FLAGS_SHOULD_RETRY);
}
}
impl Constructor for PySslMemoryBio {
type Args = ();
fn py_new(cls: PyTypeRef, _args: Self::Args, vm: &VirtualMachine) -> PyResult {
unsafe {
let bio = sys::BIO_new(sys::BIO_s_mem());
if bio.is_null() {
return Err(vm.new_memory_error("failed to allocate BIO".to_owned()));
}
sys::BIO_set_retry_read(bio);
BIO_set_mem_eof_return(bio, -1);
PySslMemoryBio {
bio,
eof_written: AtomicCell::new(false),
}
.into_ref_with_type(vm, cls)
.map(Into::into)
}
}
}
#[pyclass(with(Constructor))]
impl PySslMemoryBio {
#[pygetset]
fn pending(&self) -> usize {
unsafe { BIO_ctrl_pending(self.bio) }
}
#[pygetset]
fn eof(&self) -> bool {
let pending = unsafe { BIO_ctrl_pending(self.bio) };
pending == 0 && self.eof_written.load()
}
#[pymethod]
fn read(&self, size: OptionalArg<i32>, vm: &VirtualMachine) -> PyResult<Vec<u8>> {
unsafe {
let avail = BIO_ctrl_pending(self.bio).min(i32::MAX as usize) as i32;
let len = size.unwrap_or(-1);
let len = if len < 0 || len > avail { avail } else { len };
if len == 0 {
return Ok(Vec::new());
}
let mut buf = vec![0u8; len as usize];
let nbytes = sys::BIO_read(self.bio, buf.as_mut_ptr() as *mut _, len);
if nbytes < 0 {
return Err(convert_openssl_error(vm, ErrorStack::get()));
}
buf.truncate(nbytes as usize);
Ok(buf)
}
}
#[pymethod]
fn write(&self, data: ArgBytesLike, vm: &VirtualMachine) -> PyResult<i32> {
if self.eof_written.load() {
return Err(vm.new_exception_msg(
ssl_error(vm),
"cannot write() after write_eof()".to_owned(),
));
}
data.with_ref(|buf| unsafe {
if buf.len() > i32::MAX as usize {
return Err(
vm.new_overflow_error(format!("string longer than {} bytes", i32::MAX))
);
}
let nbytes = sys::BIO_write(self.bio, buf.as_ptr() as *const _, buf.len() as i32);
if nbytes < 0 {
return Err(convert_openssl_error(vm, ErrorStack::get()));
}
Ok(nbytes)
})
}
#[pymethod]
fn write_eof(&self) {
self.eof_written.store(true);
unsafe {
BIO_clear_retry_flags(self.bio);
BIO_set_mem_eof_return(self.bio, 0);
}
}
}
#[pyclass(with(Comparable))]
impl PySslSession {
#[pygetset]
fn time(&self) -> i64 {
unsafe {
#[cfg(ossl330)]
{
sys::SSL_SESSION_get_time(self.session) as i64
}
#[cfg(not(ossl330))]
{
sys::SSL_SESSION_get_time(self.session) as i64
}
}
}
#[pygetset]
fn timeout(&self) -> i64 {
unsafe { sys::SSL_SESSION_get_timeout(self.session) as i64 }
}
#[pygetset]
fn ticket_lifetime_hint(&self) -> u64 {
// SSL_SESSION_get_ticket_lifetime_hint may not be available in older OpenSSL
// Return 0 as default if not available
#[cfg(ossl110)]
{
// For now, return 0 as this function may not be in openssl-sys
let _ = self.session;
0
}
#[cfg(not(ossl110))]
{
let _ = self.session;
0
}
}
#[pygetset]
fn id(&self, vm: &VirtualMachine) -> PyObjectRef {
unsafe {
let mut len: libc::c_uint = 0;
let id_ptr = sys::SSL_SESSION_get_id(self.session, &mut len);
let id_slice = std::slice::from_raw_parts(id_ptr, len as usize);
vm.ctx.new_bytes(id_slice.to_vec()).into()
}
}
#[pygetset]
fn has_ticket(&self) -> bool {
// SSL_SESSION_has_ticket may not be available in older OpenSSL
// Return false as default
#[cfg(ossl110)]
{
// For now, return false as this function may not be in openssl-sys
let _ = self.session;
false
}
#[cfg(not(ossl110))]
{
let _ = self.session;
false
}
}
}
#[track_caller]
fn convert_openssl_error(vm: &VirtualMachine, err: ErrorStack) -> PyBaseExceptionRef {
let cls = ssl_error(vm);