mirror of
https://github.com/RustPython/RustPython.git
synced 2026-06-02 19:39:49 +09:00
ssl.{SSLSession,MemoryBIO} (#6209)
* SSLSession * get_unverified_chain * SSL MemoryBIO
This commit is contained in:
4
Lib/ssl.py
vendored
4
Lib/ssl.py
vendored
@@ -98,7 +98,7 @@ from enum import Enum as _Enum, IntEnum as _IntEnum, IntFlag as _IntFlag
|
||||
import _ssl # if we can't import it, let the error propagate
|
||||
|
||||
from _ssl import OPENSSL_VERSION_NUMBER, OPENSSL_VERSION_INFO, OPENSSL_VERSION
|
||||
from _ssl import _SSLContext#, MemoryBIO, SSLSession
|
||||
from _ssl import _SSLContext, SSLSession, MemoryBIO
|
||||
from _ssl import (
|
||||
SSLError, SSLZeroReturnError, SSLWantReadError, SSLWantWriteError,
|
||||
SSLSyscallError, SSLEOFError, SSLCertVerificationError
|
||||
@@ -114,7 +114,7 @@ except ImportError:
|
||||
|
||||
from _ssl import (
|
||||
HAS_SNI, HAS_ECDH, HAS_NPN, HAS_ALPN, HAS_SSLv2, HAS_SSLv3, HAS_TLSv1,
|
||||
HAS_TLSv1_1, HAS_TLSv1_2, HAS_TLSv1_3
|
||||
HAS_TLSv1_1, HAS_TLSv1_2, HAS_TLSv1_3, HAS_PSK
|
||||
)
|
||||
from _ssl import _DEFAULT_CIPHERS, _OPENSSL_API_VERSION
|
||||
|
||||
|
||||
@@ -38,15 +38,16 @@ mod _ssl {
|
||||
},
|
||||
socket::{self, PySocket},
|
||||
vm::{
|
||||
PyObjectRef, PyPayload, PyRef, PyResult, VirtualMachine,
|
||||
Py, PyObjectRef, PyPayload, PyRef, PyResult, VirtualMachine,
|
||||
builtins::{PyBaseExceptionRef, PyStrRef, PyType, PyTypeRef, PyWeak},
|
||||
class_or_notimplemented,
|
||||
convert::{ToPyException, ToPyObject},
|
||||
exceptions,
|
||||
function::{
|
||||
ArgBytesLike, ArgCallable, ArgMemoryBuffer, ArgStrOrBytesLike, Either, FsPath,
|
||||
OptionalArg,
|
||||
OptionalArg, PyComparisonValue,
|
||||
},
|
||||
types::Constructor,
|
||||
types::{Comparable, Constructor, PyComparisonOp},
|
||||
utils::ToCString,
|
||||
},
|
||||
};
|
||||
@@ -162,6 +163,8 @@ mod _ssl {
|
||||
const HAS_TLSv1_2: bool = true;
|
||||
#[pyattr]
|
||||
const HAS_TLSv1_3: bool = cfg!(ossl111);
|
||||
#[pyattr]
|
||||
const HAS_PSK: bool = true;
|
||||
|
||||
// the openssl version from the API headers
|
||||
|
||||
@@ -816,17 +819,46 @@ mod _ssl {
|
||||
let stream = ssl::SslStream::new(ssl, SocketStream(args.sock.clone()))
|
||||
.map_err(|e| convert_openssl_error(vm, e))?;
|
||||
|
||||
// TODO: use this
|
||||
let _ = args.session;
|
||||
|
||||
Ok(PySslSocket {
|
||||
let py_ssl_socket = PySslSocket {
|
||||
ctx: zelf,
|
||||
stream: PyRwLock::new(stream),
|
||||
socket_type,
|
||||
server_hostname: args.server_hostname,
|
||||
owner: PyRwLock::new(args.owner.map(|o| o.downgrade(None, vm)).transpose()?),
|
||||
})
|
||||
};
|
||||
|
||||
// Set session if provided
|
||||
if let Some(session) = args.session
|
||||
&& !vm.is_none(&session)
|
||||
{
|
||||
py_ssl_socket.set_session(session, vm)?;
|
||||
}
|
||||
|
||||
Ok(py_ssl_socket)
|
||||
}
|
||||
|
||||
#[pymethod]
|
||||
fn _wrap_bio(_zelf: PyRef<Self>, _args: WrapBioArgs, vm: &VirtualMachine) -> PyResult {
|
||||
// TODO: Implement BIO-based SSL wrapping
|
||||
// This requires refactoring PySslSocket to support both socket and BIO modes
|
||||
Err(vm.new_not_implemented_error(
|
||||
"_wrap_bio is not yet implemented in RustPython".to_owned(),
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(FromArgs)]
|
||||
#[allow(dead_code)] // Fields will be used when _wrap_bio is fully implemented
|
||||
struct WrapBioArgs {
|
||||
incoming: PyRef<PySslMemoryBio>,
|
||||
outgoing: PyRef<PySslMemoryBio>,
|
||||
server_side: bool,
|
||||
#[pyarg(any, default)]
|
||||
server_hostname: Option<PyStrRef>,
|
||||
#[pyarg(named, default)]
|
||||
owner: Option<PyObjectRef>,
|
||||
#[pyarg(named, default)]
|
||||
session: Option<PyObjectRef>,
|
||||
}
|
||||
|
||||
#[derive(FromArgs)]
|
||||
@@ -996,6 +1028,19 @@ mod _ssl {
|
||||
.transpose()
|
||||
}
|
||||
|
||||
#[pymethod]
|
||||
fn get_unverified_chain(&self, vm: &VirtualMachine) -> Option<PyObjectRef> {
|
||||
let stream = self.stream.read();
|
||||
let chain = stream.ssl().peer_cert_chain()?;
|
||||
|
||||
let certs: Vec<PyObjectRef> = chain
|
||||
.iter()
|
||||
.filter_map(|cert| cert.to_der().ok().map(|der| vm.ctx.new_bytes(der).into()))
|
||||
.collect();
|
||||
|
||||
Some(vm.ctx.new_list(certs).into())
|
||||
}
|
||||
|
||||
#[pymethod]
|
||||
fn version(&self) -> Option<&'static str> {
|
||||
let v = self.stream.read().ssl().version_str();
|
||||
@@ -1103,6 +1148,73 @@ mod _ssl {
|
||||
}
|
||||
}
|
||||
|
||||
#[pygetset]
|
||||
fn session(&self, _vm: &VirtualMachine) -> PyResult<Option<PySslSession>> {
|
||||
let stream = self.stream.read();
|
||||
unsafe {
|
||||
let session_ptr = sys::SSL_get_session(stream.ssl().as_ptr());
|
||||
if session_ptr.is_null() {
|
||||
Ok(None)
|
||||
} else {
|
||||
// Increment reference count since SSL_get_session returns a borrowed reference
|
||||
#[cfg(ossl110)]
|
||||
let _session = sys::SSL_SESSION_up_ref(session_ptr);
|
||||
|
||||
Ok(Some(PySslSession {
|
||||
session: session_ptr,
|
||||
ctx: self.ctx.clone(),
|
||||
}))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[pygetset(setter)]
|
||||
fn set_session(&self, value: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> {
|
||||
// Check if value is SSLSession type
|
||||
let session = value
|
||||
.downcast_ref::<PySslSession>()
|
||||
.ok_or_else(|| vm.new_type_error("Value is not a SSLSession.".to_owned()))?;
|
||||
|
||||
// Check if session refers to the same SSLContext
|
||||
if !std::ptr::eq(
|
||||
self.ctx.ctx.read().as_ptr(),
|
||||
session.ctx.ctx.read().as_ptr(),
|
||||
) {
|
||||
return Err(
|
||||
vm.new_value_error("Session refers to a different SSLContext.".to_owned())
|
||||
);
|
||||
}
|
||||
|
||||
// Check if this is a client socket
|
||||
if self.socket_type != SslServerOrClient::Client {
|
||||
return Err(
|
||||
vm.new_value_error("Cannot set session for server-side SSLSocket.".to_owned())
|
||||
);
|
||||
}
|
||||
|
||||
// Check if handshake is not finished
|
||||
let stream = self.stream.read();
|
||||
unsafe {
|
||||
if sys::SSL_is_init_finished(stream.ssl().as_ptr()) != 0 {
|
||||
return Err(
|
||||
vm.new_value_error("Cannot set session after handshake.".to_owned())
|
||||
);
|
||||
}
|
||||
|
||||
if sys::SSL_set_session(stream.ssl().as_ptr(), session.session) == 0 {
|
||||
return Err(convert_openssl_error(vm, ErrorStack::get()));
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[pygetset]
|
||||
fn session_reused(&self) -> bool {
|
||||
let stream = self.stream.read();
|
||||
unsafe { sys::SSL_session_reused(stream.ssl().as_ptr()) != 0 }
|
||||
}
|
||||
|
||||
#[pymethod]
|
||||
fn read(
|
||||
&self,
|
||||
@@ -1164,6 +1276,282 @@ mod _ssl {
|
||||
}
|
||||
}
|
||||
|
||||
#[pyattr]
|
||||
#[pyclass(module = "ssl", name = "SSLSession")]
|
||||
#[derive(PyPayload)]
|
||||
struct PySslSession {
|
||||
session: *mut sys::SSL_SESSION,
|
||||
ctx: PyRef<PySslContext>,
|
||||
}
|
||||
|
||||
impl fmt::Debug for PySslSession {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
f.pad("SSLSession")
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for PySslSession {
|
||||
fn drop(&mut self) {
|
||||
if !self.session.is_null() {
|
||||
unsafe {
|
||||
sys::SSL_SESSION_free(self.session);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl Send for PySslSession {}
|
||||
unsafe impl Sync for PySslSession {}
|
||||
|
||||
impl Comparable for PySslSession {
|
||||
fn cmp(
|
||||
zelf: &Py<Self>,
|
||||
other: &crate::vm::PyObject,
|
||||
op: PyComparisonOp,
|
||||
_vm: &VirtualMachine,
|
||||
) -> PyResult<PyComparisonValue> {
|
||||
let other = class_or_notimplemented!(Self, other);
|
||||
|
||||
if !matches!(op, PyComparisonOp::Eq | PyComparisonOp::Ne) {
|
||||
return Ok(PyComparisonValue::NotImplemented);
|
||||
}
|
||||
let mut eq = unsafe {
|
||||
let mut self_len: libc::c_uint = 0;
|
||||
let mut other_len: libc::c_uint = 0;
|
||||
let self_id = sys::SSL_SESSION_get_id(zelf.session, &mut self_len);
|
||||
let other_id = sys::SSL_SESSION_get_id(other.session, &mut other_len);
|
||||
|
||||
if self_len != other_len {
|
||||
false
|
||||
} else {
|
||||
let self_slice = std::slice::from_raw_parts(self_id, self_len as usize);
|
||||
let other_slice = std::slice::from_raw_parts(other_id, other_len as usize);
|
||||
self_slice == other_slice
|
||||
}
|
||||
};
|
||||
if matches!(op, PyComparisonOp::Ne) {
|
||||
eq = !eq;
|
||||
}
|
||||
Ok(PyComparisonValue::Implemented(eq))
|
||||
}
|
||||
}
|
||||
|
||||
#[pyattr]
|
||||
#[pyclass(module = "ssl", name = "MemoryBIO")]
|
||||
#[derive(PyPayload)]
|
||||
struct PySslMemoryBio {
|
||||
bio: *mut sys::BIO,
|
||||
eof_written: AtomicCell<bool>,
|
||||
}
|
||||
|
||||
impl fmt::Debug for PySslMemoryBio {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
f.pad("MemoryBIO")
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for PySslMemoryBio {
|
||||
fn drop(&mut self) {
|
||||
if !self.bio.is_null() {
|
||||
unsafe {
|
||||
sys::BIO_free_all(self.bio);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl Send for PySslMemoryBio {}
|
||||
unsafe impl Sync for PySslMemoryBio {}
|
||||
|
||||
// OpenSSL BIO helper functions
|
||||
// These are typically macros in OpenSSL, implemented via BIO_ctrl
|
||||
const BIO_CTRL_PENDING: libc::c_int = 10;
|
||||
const BIO_CTRL_SET_EOF: libc::c_int = 2;
|
||||
|
||||
#[allow(non_snake_case)]
|
||||
unsafe fn BIO_ctrl_pending(bio: *mut sys::BIO) -> usize {
|
||||
unsafe { sys::BIO_ctrl(bio, BIO_CTRL_PENDING, 0, std::ptr::null_mut()) as usize }
|
||||
}
|
||||
|
||||
#[allow(non_snake_case)]
|
||||
unsafe fn BIO_set_mem_eof_return(bio: *mut sys::BIO, eof: libc::c_int) -> libc::c_int {
|
||||
unsafe {
|
||||
sys::BIO_ctrl(
|
||||
bio,
|
||||
BIO_CTRL_SET_EOF,
|
||||
eof as libc::c_long,
|
||||
std::ptr::null_mut(),
|
||||
) as libc::c_int
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(non_snake_case)]
|
||||
unsafe fn BIO_clear_retry_flags(bio: *mut sys::BIO) {
|
||||
unsafe {
|
||||
sys::BIO_clear_flags(bio, sys::BIO_FLAGS_RWS | sys::BIO_FLAGS_SHOULD_RETRY);
|
||||
}
|
||||
}
|
||||
|
||||
impl Constructor for PySslMemoryBio {
|
||||
type Args = ();
|
||||
|
||||
fn py_new(cls: PyTypeRef, _args: Self::Args, vm: &VirtualMachine) -> PyResult {
|
||||
unsafe {
|
||||
let bio = sys::BIO_new(sys::BIO_s_mem());
|
||||
if bio.is_null() {
|
||||
return Err(vm.new_memory_error("failed to allocate BIO".to_owned()));
|
||||
}
|
||||
|
||||
sys::BIO_set_retry_read(bio);
|
||||
BIO_set_mem_eof_return(bio, -1);
|
||||
|
||||
PySslMemoryBio {
|
||||
bio,
|
||||
eof_written: AtomicCell::new(false),
|
||||
}
|
||||
.into_ref_with_type(vm, cls)
|
||||
.map(Into::into)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[pyclass(with(Constructor))]
|
||||
impl PySslMemoryBio {
|
||||
#[pygetset]
|
||||
fn pending(&self) -> usize {
|
||||
unsafe { BIO_ctrl_pending(self.bio) }
|
||||
}
|
||||
|
||||
#[pygetset]
|
||||
fn eof(&self) -> bool {
|
||||
let pending = unsafe { BIO_ctrl_pending(self.bio) };
|
||||
pending == 0 && self.eof_written.load()
|
||||
}
|
||||
|
||||
#[pymethod]
|
||||
fn read(&self, size: OptionalArg<i32>, vm: &VirtualMachine) -> PyResult<Vec<u8>> {
|
||||
unsafe {
|
||||
let avail = BIO_ctrl_pending(self.bio).min(i32::MAX as usize) as i32;
|
||||
let len = size.unwrap_or(-1);
|
||||
let len = if len < 0 || len > avail { avail } else { len };
|
||||
|
||||
if len == 0 {
|
||||
return Ok(Vec::new());
|
||||
}
|
||||
|
||||
let mut buf = vec![0u8; len as usize];
|
||||
let nbytes = sys::BIO_read(self.bio, buf.as_mut_ptr() as *mut _, len);
|
||||
|
||||
if nbytes < 0 {
|
||||
return Err(convert_openssl_error(vm, ErrorStack::get()));
|
||||
}
|
||||
|
||||
buf.truncate(nbytes as usize);
|
||||
Ok(buf)
|
||||
}
|
||||
}
|
||||
|
||||
#[pymethod]
|
||||
fn write(&self, data: ArgBytesLike, vm: &VirtualMachine) -> PyResult<i32> {
|
||||
if self.eof_written.load() {
|
||||
return Err(vm.new_exception_msg(
|
||||
ssl_error(vm),
|
||||
"cannot write() after write_eof()".to_owned(),
|
||||
));
|
||||
}
|
||||
|
||||
data.with_ref(|buf| unsafe {
|
||||
if buf.len() > i32::MAX as usize {
|
||||
return Err(
|
||||
vm.new_overflow_error(format!("string longer than {} bytes", i32::MAX))
|
||||
);
|
||||
}
|
||||
|
||||
let nbytes = sys::BIO_write(self.bio, buf.as_ptr() as *const _, buf.len() as i32);
|
||||
if nbytes < 0 {
|
||||
return Err(convert_openssl_error(vm, ErrorStack::get()));
|
||||
}
|
||||
|
||||
Ok(nbytes)
|
||||
})
|
||||
}
|
||||
|
||||
#[pymethod]
|
||||
fn write_eof(&self) {
|
||||
self.eof_written.store(true);
|
||||
unsafe {
|
||||
BIO_clear_retry_flags(self.bio);
|
||||
BIO_set_mem_eof_return(self.bio, 0);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[pyclass(with(Comparable))]
|
||||
impl PySslSession {
|
||||
#[pygetset]
|
||||
fn time(&self) -> i64 {
|
||||
unsafe {
|
||||
#[cfg(ossl330)]
|
||||
{
|
||||
sys::SSL_SESSION_get_time(self.session) as i64
|
||||
}
|
||||
#[cfg(not(ossl330))]
|
||||
{
|
||||
sys::SSL_SESSION_get_time(self.session) as i64
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[pygetset]
|
||||
fn timeout(&self) -> i64 {
|
||||
unsafe { sys::SSL_SESSION_get_timeout(self.session) as i64 }
|
||||
}
|
||||
|
||||
#[pygetset]
|
||||
fn ticket_lifetime_hint(&self) -> u64 {
|
||||
// SSL_SESSION_get_ticket_lifetime_hint may not be available in older OpenSSL
|
||||
// Return 0 as default if not available
|
||||
#[cfg(ossl110)]
|
||||
{
|
||||
// For now, return 0 as this function may not be in openssl-sys
|
||||
let _ = self.session;
|
||||
0
|
||||
}
|
||||
#[cfg(not(ossl110))]
|
||||
{
|
||||
let _ = self.session;
|
||||
0
|
||||
}
|
||||
}
|
||||
|
||||
#[pygetset]
|
||||
fn id(&self, vm: &VirtualMachine) -> PyObjectRef {
|
||||
unsafe {
|
||||
let mut len: libc::c_uint = 0;
|
||||
let id_ptr = sys::SSL_SESSION_get_id(self.session, &mut len);
|
||||
let id_slice = std::slice::from_raw_parts(id_ptr, len as usize);
|
||||
vm.ctx.new_bytes(id_slice.to_vec()).into()
|
||||
}
|
||||
}
|
||||
|
||||
#[pygetset]
|
||||
fn has_ticket(&self) -> bool {
|
||||
// SSL_SESSION_has_ticket may not be available in older OpenSSL
|
||||
// Return false as default
|
||||
#[cfg(ossl110)]
|
||||
{
|
||||
// For now, return false as this function may not be in openssl-sys
|
||||
let _ = self.session;
|
||||
false
|
||||
}
|
||||
#[cfg(not(ossl110))]
|
||||
{
|
||||
let _ = self.session;
|
||||
false
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[track_caller]
|
||||
fn convert_openssl_error(vm: &VirtualMachine, err: ErrorStack) -> PyBaseExceptionRef {
|
||||
let cls = ssl_error(vm);
|
||||
|
||||
Reference in New Issue
Block a user