Merge pull request #3147 from youknowone/ssl-module

Refactor ssl module
This commit is contained in:
Jeong YunWon
2021-09-27 17:12:26 +09:00
committed by GitHub
2 changed files with 223 additions and 198 deletions

View File

@@ -5,7 +5,7 @@ use crate::{
exceptions::{IntoPyException, PyBaseExceptionRef},
function::{FuncArgs, OptionalArg, OptionalOption},
utils::{Either, ToCString},
IntoPyObject, PyClassImpl, PyObjectRef, PyRef, PyResult, PyValue, TryFromBorrowedObject,
IntoPyObject, PyClassImpl, PyObjectRef, PyResult, PyValue, TryFromBorrowedObject,
TryFromObject, TypeProtocol, VirtualMachine,
};
use crossbeam_utils::atomic::AtomicCell;
@@ -18,7 +18,10 @@ use std::convert::TryFrom;
use std::mem::MaybeUninit;
use std::net::{self, Ipv4Addr, Ipv6Addr, Shutdown, SocketAddr, ToSocketAddrs};
use std::time::{Duration, Instant};
use std::{ffi, io};
use std::{
ffi,
io::{self, Read, Write},
};
#[cfg(unix)]
type RawSocket = std::os::unix::io::RawFd;
@@ -158,13 +161,26 @@ impl Default for PySocket {
}
}
pub type PySocketRef = PyRef<PySocket>;
#[cfg(windows)]
const CLOSED_ERR: i32 = c::WSAENOTSOCK;
#[cfg(unix)]
const CLOSED_ERR: i32 = c::EBADF;
#[pyimpl(flags(BASETYPE))]
impl Read for &PySocket {
fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
(&mut &*self.sock_io()?).read(buf)
}
}
impl Write for &PySocket {
fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
(&mut &*self.sock_io()?).write(buf)
}
fn flush(&mut self) -> std::io::Result<()> {
(&mut &*self.sock_io()?).flush()
}
}
impl PySocket {
pub fn sock_opt(&self) -> Option<PyMappedRwLockReadGuard<'_, Socket>> {
PyRwLockReadGuard::try_map(self.sock.read(), |sock| sock.get()).ok()
@@ -179,94 +195,6 @@ impl PySocket {
self.sock_io().map_err(|e| e.into_pyexception(vm))
}
#[pyslot]
fn tp_new(cls: PyTypeRef, _args: FuncArgs, vm: &VirtualMachine) -> PyResult {
Self::default().into_pyresult_with_type(vm, cls)
}
#[pymethod(magic)]
fn init(
&self,
family: OptionalArg<i32>,
socket_kind: OptionalArg<i32>,
proto: OptionalArg<i32>,
fileno: OptionalOption<PyObjectRef>,
vm: &VirtualMachine,
) -> PyResult<()> {
let mut family = family.unwrap_or(-1);
let mut socket_kind = socket_kind.unwrap_or(-1);
let mut proto = proto.unwrap_or(-1);
let fileno = fileno
.flatten()
.map(|obj| get_raw_sock(obj, vm))
.transpose()?;
let sock;
if let Some(fileno) = fileno {
sock = sock_from_raw(fileno, vm)?;
match sock.local_addr() {
Ok(addr) if family == -1 => family = addr.family() as i32,
Err(e)
if family == -1
|| matches!(
e.raw_os_error(),
Some(errcode!(ENOTSOCK)) | Some(errcode!(EBADF))
) =>
{
std::mem::forget(sock);
return Err(e.into_pyexception(vm));
}
_ => {}
}
if socket_kind == -1 {
// TODO: when socket2 cuts a new release, type will be available on all os
// socket_kind = sock.r#type().map_err(|e| e.into_pyexception(vm))?.into();
let res = unsafe {
c::getsockopt(
sock_fileno(&sock) as _,
c::SOL_SOCKET,
c::SO_TYPE,
&mut socket_kind as *mut libc::c_int as *mut _,
&mut (std::mem::size_of::<i32>() as _),
)
};
if res < 0 {
return Err(super::os::errno_err(vm));
}
}
cfg_if::cfg_if! {
if #[cfg(any(
target_os = "android",
target_os = "freebsd",
target_os = "fuchsia",
target_os = "linux",
))] {
if proto == -1 {
proto = sock.protocol().map_err(|e| e.into_pyexception(vm))?.map_or(0, Into::into);
}
} else {
proto = 0;
}
}
} else {
if family == -1 {
family = c::AF_INET as i32
}
if socket_kind == -1 {
socket_kind = c::SOCK_STREAM
}
if proto == -1 {
proto = 0
}
sock = Socket::new(
Domain::from(family),
SocketType::from(socket_kind),
Some(Protocol::from(proto)),
)
.map_err(|err| err.into_pyexception(vm))?;
};
self.init_inner(family, socket_kind, proto, sock, vm)
}
fn init_inner(
&self,
family: i32,
@@ -523,6 +451,97 @@ impl PySocket {
Err(err.into())
}
}
}
#[pyimpl(flags(BASETYPE))]
impl PySocket {
#[pyslot]
fn tp_new(cls: PyTypeRef, _args: FuncArgs, vm: &VirtualMachine) -> PyResult {
Self::default().into_pyresult_with_type(vm, cls)
}
#[pymethod(magic)]
fn init(
&self,
family: OptionalArg<i32>,
socket_kind: OptionalArg<i32>,
proto: OptionalArg<i32>,
fileno: OptionalOption<PyObjectRef>,
vm: &VirtualMachine,
) -> PyResult<()> {
let mut family = family.unwrap_or(-1);
let mut socket_kind = socket_kind.unwrap_or(-1);
let mut proto = proto.unwrap_or(-1);
let fileno = fileno
.flatten()
.map(|obj| get_raw_sock(obj, vm))
.transpose()?;
let sock;
if let Some(fileno) = fileno {
sock = sock_from_raw(fileno, vm)?;
match sock.local_addr() {
Ok(addr) if family == -1 => family = addr.family() as i32,
Err(e)
if family == -1
|| matches!(
e.raw_os_error(),
Some(errcode!(ENOTSOCK)) | Some(errcode!(EBADF))
) =>
{
std::mem::forget(sock);
return Err(e.into_pyexception(vm));
}
_ => {}
}
if socket_kind == -1 {
// TODO: when socket2 cuts a new release, type will be available on all os
// socket_kind = sock.r#type().map_err(|e| e.into_pyexception(vm))?.into();
let res = unsafe {
c::getsockopt(
sock_fileno(&sock) as _,
c::SOL_SOCKET,
c::SO_TYPE,
&mut socket_kind as *mut libc::c_int as *mut _,
&mut (std::mem::size_of::<i32>() as _),
)
};
if res < 0 {
return Err(super::os::errno_err(vm));
}
}
cfg_if::cfg_if! {
if #[cfg(any(
target_os = "android",
target_os = "freebsd",
target_os = "fuchsia",
target_os = "linux",
))] {
if proto == -1 {
proto = sock.protocol().map_err(|e| e.into_pyexception(vm))?.map_or(0, Into::into);
}
} else {
proto = 0;
}
}
} else {
if family == -1 {
family = c::AF_INET as i32
}
if socket_kind == -1 {
socket_kind = c::SOCK_STREAM
}
if proto == -1 {
proto = 0
}
sock = Socket::new(
Domain::from(family),
SocketType::from(socket_kind),
Some(Protocol::from(proto)),
)
.map_err(|err| err.into_pyexception(vm))?;
};
self.init_inner(family, socket_kind, proto, sock, vm)
}
#[pymethod]
fn connect(&self, address: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> {
@@ -919,20 +938,6 @@ impl PySocket {
}
}
impl io::Read for PySocketRef {
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
<&Socket as io::Read>::read(&mut &*self.sock_io()?, buf)
}
}
impl io::Write for PySocketRef {
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
<&Socket as io::Write>::write(&mut &*self.sock_io()?, buf)
}
fn flush(&mut self) -> io::Result<()> {
<&Socket as io::Write>::flush(&mut &*self.sock_io()?)
}
}
struct Address {
host: PyStrRef,
port: u16,

View File

@@ -1,4 +1,4 @@
use super::socket::{self, PySocketRef};
use super::socket::{self, PySocket};
use crate::common::lock::{PyRwLock, PyRwLockWriteGuard};
use crate::{
builtins::{pytype, weakref::PyWeak, PyStrRef, PyTypeRef},
@@ -22,6 +22,7 @@ use openssl::{
use std::convert::TryFrom;
use std::ffi::CStr;
use std::fmt;
use std::io::{Read, Write};
use std::time::Instant;
mod sys {
@@ -105,9 +106,9 @@ fn nid2obj(nid: Nid) -> Option<Asn1Object> {
unsafe { ptr2obj(sys::OBJ_nid2obj(nid.as_raw())) }
}
fn obj2txt(obj: &Asn1ObjectRef, no_name: bool) -> Option<String> {
unsafe {
let no_name = if no_name { 1 } else { 0 };
let ptr = obj.as_ptr();
let no_name = if no_name { 1 } else { 0 };
let ptr = obj.as_ptr();
let s = unsafe {
let buflen = sys::OBJ_obj2txt(std::ptr::null_mut(), 0, ptr, no_name);
assert!(buflen >= 0);
if buflen == 0 {
@@ -116,10 +117,10 @@ fn obj2txt(obj: &Asn1ObjectRef, no_name: bool) -> Option<String> {
let mut buf = vec![0u8; buflen as usize];
let ret = sys::OBJ_obj2txt(buf.as_mut_ptr() as *mut libc::c_char, buflen, ptr, no_name);
assert!(ret >= 0);
let s = String::from_utf8(buf)
.unwrap_or_else(|e| String::from_utf8_lossy(e.as_bytes()).into_owned());
Some(s)
}
String::from_utf8(buf)
.unwrap_or_else(|e| String::from_utf8_lossy(e.as_bytes()).into_owned())
};
Some(s)
}
type PyNid = (libc::c_int, String, String, Option<String>);
@@ -232,9 +233,8 @@ fn _ssl_rand_bytes(n: i32, vm: &VirtualMachine) -> PyResult<Vec<u8>> {
return Err(vm.new_value_error("num must be positive".to_owned()));
}
let mut buf = vec![0; n as usize];
openssl::rand::rand_bytes(&mut buf)
.map(|()| buf)
.map_err(|e| convert_openssl_error(vm, e))
openssl::rand::rand_bytes(&mut buf).map_err(|e| convert_openssl_error(vm, e))?;
Ok(buf)
}
fn _ssl_rand_pseudo_bytes(n: i32, vm: &VirtualMachine) -> PyResult<(Vec<u8>, bool)> {
@@ -592,7 +592,7 @@ impl PySslContext {
}
}
let stream = ssl::SslStream::new(ssl, args.sock.clone())
let stream = ssl::SslStream::new(ssl, SocketStream(args.sock.clone()))
.map_err(|e| convert_openssl_error(vm, e))?;
// TODO: use this
@@ -611,7 +611,7 @@ impl PySslContext {
#[derive(FromArgs)]
struct WrapSocketArgs {
#[pyarg(any)]
sock: PySocketRef,
sock: PyRef<PySocket>,
#[pyarg(any)]
server_side: bool,
#[pyarg(any, default)]
@@ -642,16 +642,9 @@ struct LoadCertChainArgs {
password: Option<Either<PyStrRef, ArgCallable>>,
}
struct SocketTimeout {
// Err is true if the socket is blocking
deadline: Result<Instant, bool>,
}
impl SocketTimeout {
fn get(s: &socket::PySocket) -> Self {
let deadline = s.get_timeout().map(|d| Instant::now() + d);
Self { deadline }
}
}
// Err is true if the socket is blocking
type SocketDeadline = Result<Instant, bool>;
enum SelectRet {
Nonblocking,
TimedOut,
@@ -659,50 +652,60 @@ enum SelectRet {
Closed,
Ok,
}
fn ssl_select(sock: &socket::PySocket, needs: SslNeeds, timeout: &SocketTimeout) -> SelectRet {
let sock = match sock.sock_opt() {
Some(s) => s,
None => return SelectRet::Closed,
};
let timeout = match &timeout.deadline {
Ok(deadline) => match deadline.checked_duration_since(Instant::now()) {
Some(timeout) => timeout,
None => return SelectRet::TimedOut,
},
Err(true) => return SelectRet::IsBlocking,
Err(false) => return SelectRet::Nonblocking,
};
let res = socket::sock_select(
&sock,
match needs {
SslNeeds::Read => socket::SelectKind::Read,
SslNeeds::Write => socket::SelectKind::Write,
},
Some(timeout),
);
match res {
Ok(true) => SelectRet::TimedOut,
_ => SelectRet::Ok,
}
}
#[derive(Clone, Copy)]
enum SslNeeds {
Read,
Write,
}
fn socket_needs(
err: &ssl::Error,
sock: &socket::PySocket,
timeout: &SocketTimeout,
) -> (Option<SslNeeds>, SelectRet) {
let needs = match err.code() {
ssl::ErrorCode::WANT_READ => Some(SslNeeds::Read),
ssl::ErrorCode::WANT_WRITE => Some(SslNeeds::Write),
_ => None,
};
let state = needs.map_or(SelectRet::Ok, |needs| ssl_select(sock, needs, timeout));
(needs, state)
struct SocketStream(PyRef<PySocket>);
impl SocketStream {
fn timeout_deadline(&self) -> SocketDeadline {
self.0.get_timeout().map(|d| Instant::now() + d)
}
fn select(&self, needs: SslNeeds, deadline: &SocketDeadline) -> SelectRet {
let sock = match self.0.sock_opt() {
Some(s) => s,
None => return SelectRet::Closed,
};
let deadline = match &deadline {
Ok(deadline) => match deadline.checked_duration_since(Instant::now()) {
Some(deadline) => deadline,
None => return SelectRet::TimedOut,
},
Err(true) => return SelectRet::IsBlocking,
Err(false) => return SelectRet::Nonblocking,
};
let res = socket::sock_select(
&sock,
match needs {
SslNeeds::Read => socket::SelectKind::Read,
SslNeeds::Write => socket::SelectKind::Write,
},
Some(deadline),
);
match res {
Ok(true) => SelectRet::TimedOut,
_ => SelectRet::Ok,
}
}
fn socket_needs(
&self,
err: &ssl::Error,
deadline: &SocketDeadline,
) -> (Option<SslNeeds>, SelectRet) {
let needs = match err.code() {
ssl::ErrorCode::WANT_READ => Some(SslNeeds::Read),
ssl::ErrorCode::WANT_WRITE => Some(SslNeeds::Write),
_ => None,
};
let state = needs.map_or(SelectRet::Ok, |needs| self.select(needs, deadline));
(needs, state)
}
}
fn socket_closed_error(vm: &VirtualMachine) -> PyBaseExceptionRef {
@@ -716,7 +719,7 @@ fn socket_closed_error(vm: &VirtualMachine) -> PyBaseExceptionRef {
#[derive(PyValue)]
struct PySslSocket {
ctx: PyRef<PySslContext>,
stream: PyRwLock<ssl::SslStream<PySocketRef>>,
stream: PyRwLock<ssl::SslStream<SocketStream>>,
socket_type: SslServerOrClient,
server_hostname: Option<PyStrRef>,
owner: PyRwLock<Option<PyWeak>>,
@@ -788,38 +791,37 @@ impl PySslSocket {
.map(cipher_to_tuple)
}
#[cfg(osslconf = "OPENSSL_NO_COMP")]
#[pymethod]
fn compression(&self) -> Option<&'static str> {
#[cfg(osslconf = "OPENSSL_NO_COMP")]
{
None
None
}
#[cfg(not(osslconf = "OPENSSL_NO_COMP"))]
#[pymethod]
fn compression(&self) -> Option<&'static str> {
let stream = self.stream.read();
let comp_method = unsafe { sys::SSL_get_current_compression(stream.ssl().as_ptr()) };
if comp_method.is_null() {
return None;
}
#[cfg(not(osslconf = "OPENSSL_NO_COMP"))]
{
let stream = self.stream.read();
let comp_method = unsafe { sys::SSL_get_current_compression(stream.ssl().as_ptr()) };
if comp_method.is_null() {
return None;
}
let typ = unsafe { sys::COMP_get_type(comp_method) };
let nid = Nid::from_raw(typ);
if nid == Nid::UNDEF {
return None;
}
nid.short_name().ok()
let typ = unsafe { sys::COMP_get_type(comp_method) };
let nid = Nid::from_raw(typ);
if nid == Nid::UNDEF {
return None;
}
nid.short_name().ok()
}
#[pymethod]
fn do_handshake(&self, vm: &VirtualMachine) -> PyResult<()> {
let mut stream = self.stream.write();
let timeout = SocketTimeout::get(stream.get_ref());
let timeout = stream.get_ref().timeout_deadline();
loop {
let err = match stream.do_handshake() {
Ok(()) => return Ok(()),
Err(e) => e,
};
let (needs, state) = socket_needs(&err, stream.get_ref(), &timeout);
let (needs, state) = stream.get_ref().socket_needs(&err, &timeout);
match state {
SelectRet::TimedOut => {
return Err(socket::timeout_error_msg(
@@ -844,8 +846,8 @@ impl PySslSocket {
let mut stream = self.stream.write();
let data = data.borrow_buf();
let data = &*data;
let timeout = SocketTimeout::get(stream.get_ref());
let state = ssl_select(stream.get_ref(), SslNeeds::Write, &timeout);
let timeout = stream.get_ref().timeout_deadline();
let state = stream.get_ref().select(SslNeeds::Write, &timeout);
match state {
SelectRet::TimedOut => {
return Err(socket::timeout_error_msg(
@@ -861,7 +863,7 @@ impl PySslSocket {
Ok(len) => return Ok(len),
Err(e) => e,
};
let (needs, state) = socket_needs(&err, stream.get_ref(), &timeout);
let (needs, state) = stream.get_ref().socket_needs(&err, &timeout);
match state {
SelectRet::TimedOut => {
return Err(socket::timeout_error_msg(
@@ -902,7 +904,7 @@ impl PySslSocket {
Some(b) => b,
None => buf,
};
let timeout = SocketTimeout::get(stream.get_ref());
let timeout = stream.get_ref().timeout_deadline();
let count = loop {
let err = match stream.ssl_read(buf) {
Ok(count) => break count,
@@ -913,7 +915,7 @@ impl PySslSocket {
{
break 0;
}
let (needs, state) = socket_needs(&err, stream.get_ref(), &timeout);
let (needs, state) = stream.get_ref().socket_needs(&err, &timeout);
match state {
SelectRet::TimedOut => {
return Err(socket::timeout_error_msg(
@@ -996,10 +998,9 @@ fn cipher_to_tuple(cipher: &ssl::SslCipherRef) -> CipherTuple {
}
fn cert_to_py(vm: &VirtualMachine, cert: &X509Ref, binary: bool) -> PyResult {
if binary {
cert.to_der()
.map(|b| vm.ctx.new_bytes(b))
.map_err(|e| convert_openssl_error(vm, e))
let r = if binary {
let b = cert.to_der().map_err(|e| convert_openssl_error(vm, e))?;
vm.ctx.new_bytes(b)
} else {
let dict = vm.ctx.new_dict();
@@ -1073,8 +1074,9 @@ fn cert_to_py(vm: &VirtualMachine, cert: &X509Ref, binary: bool) -> PyResult {
dict.set_item("subjectAltName", vm.ctx.new_tuple(san), vm)?;
};
Ok(dict.into_object())
}
dict.into_object()
};
Ok(r)
}
#[allow(non_snake_case)]
@@ -1237,3 +1239,21 @@ fn extend_module_platform_specific(module: &PyObjectRef, vm: &VirtualMachine) {
#[cfg(not(windows))]
fn extend_module_platform_specific(_module: &PyObjectRef, _vm: &VirtualMachine) {}
impl Read for SocketStream {
fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
let mut socket: &PySocket = &self.0;
socket.read(buf)
}
}
impl Write for SocketStream {
fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
let mut socket: &PySocket = &self.0;
socket.write(buf)
}
fn flush(&mut self) -> std::io::Result<()> {
let mut socket: &PySocket = &self.0;
socket.flush()
}
}