Add ctx.load_verify_locations(cadata=), and ctx.get_ca_certs

This commit is contained in:
Noah
2020-03-22 17:06:53 -05:00
parent 126f41e003
commit 4f64afb8cf
3 changed files with 114 additions and 24 deletions

2
Cargo.lock generated
View File

@@ -1600,7 +1600,7 @@ dependencies = [
"flame",
"flamer",
"flate2",
"foreign-types-shared",
"foreign-types",
"gethostname",
"getrandom",
"hex",

View File

@@ -70,7 +70,7 @@ paste = "0.1"
base64 = "0.11"
is-macro = "0.1"
result-like = "^0.2.1"
foreign-types-shared = "0.1"
foreign-types = "0.3"
num_enum = "0.4"
flame = { version = "0.2", optional = true }

View File

@@ -12,19 +12,22 @@ use crate::pyobject::{
use crate::types::create_type;
use crate::VirtualMachine;
use std::cell::{RefCell, RefMut};
use std::cell::{Ref, RefCell, RefMut};
use std::convert::TryFrom;
use std::ffi::{CStr, CString};
use std::fmt;
use foreign_types_shared::{ForeignType, ForeignTypeRef};
use foreign_types::{ForeignType, ForeignTypeRef};
use openssl::{
asn1::{Asn1Object, Asn1ObjectRef},
error::ErrorStack,
nid::Nid,
ssl::{self, SslContextBuilder, SslOptions, SslVerifyMode},
x509::{X509Ref, X509},
};
mod sys {
#![allow(non_camel_case_types, unused)]
use libc::{c_char, c_double, c_int, c_void};
pub use openssl_sys::*;
extern "C" {
@@ -40,7 +43,54 @@ mod sys {
pub fn SSL_CTX_set_post_handshake_auth(ctx: *mut SSL_CTX, val: c_int);
pub fn RAND_add(buf: *const c_void, num: c_int, randomness: c_double);
pub fn RAND_pseudo_bytes(buf: *const u8, num: c_int) -> c_int;
pub fn X509_STORE_get0_objects(ctx: *mut X509_STORE) -> *mut stack_st_X509_OBJECT;
pub fn X509_OBJECT_free(a: *mut X509_OBJECT);
}
pub enum stack_st_X509_OBJECT {}
pub type X509_LOOKUP_TYPE = c_int;
pub const X509_LU_NONE: X509_LOOKUP_TYPE = 0;
pub const X509_LU_X509: X509_LOOKUP_TYPE = 1;
pub const X509_LU_CRL: X509_LOOKUP_TYPE = 2;
#[repr(C)]
pub struct X509_OBJECT {
pub r#type: X509_LOOKUP_TYPE,
pub data: X509_OBJECT_data,
}
#[repr(C)]
pub union X509_OBJECT_data {
pub ptr: *mut c_char,
pub x509: *mut X509,
pub crl: *mut X509_CRL,
pub pkey: *mut EVP_PKEY,
}
}
// TODO: upstream this into rust-openssl
foreign_types::foreign_type! {
type CType = sys::X509_OBJECT;
fn drop = sys::X509_OBJECT_free;
pub struct X509Object;
pub struct X509ObjectRef;
}
impl X509ObjectRef {
fn x509(&self) -> Option<&X509Ref> {
let ptr = self.as_ptr();
let ty = unsafe { (*ptr).r#type };
if ty == sys::X509_LU_X509 {
Some(unsafe { X509Ref::from_ptr((*ptr).data.x509) })
} else {
None
}
}
}
impl openssl::stack::Stackable for X509Object {
type StackType = sys::stack_st_X509_OBJECT;
}
#[derive(num_enum::IntoPrimitive, num_enum::TryFromPrimitive, PartialEq)]
@@ -224,7 +274,7 @@ fn ssl_rand_pseudo_bytes(n: i32, vm: &VirtualMachine) -> PyResult<(Vec<u8>, bool
let ret = unsafe { sys::RAND_pseudo_bytes(buf.as_mut_ptr(), n) };
match ret {
0 | 1 => Ok((buf, ret == 1)),
_ => Err(convert_openssl_error(vm, openssl::error::ErrorStack::get())),
_ => Err(convert_openssl_error(vm, ErrorStack::get())),
}
}
@@ -251,11 +301,11 @@ impl PySslContext {
fn builder(&self) -> RefMut<SslContextBuilder> {
self.ctx.borrow_mut()
}
// fn ctx(&self) -> Ref<SslContextRef> {
// Ref::map(self.ctx.borrow(), |ctx| unsafe {
// SslContextRef::from_ptr(ctx.as_ptr())
// })
// }
fn ctx(&self) -> Ref<ssl::SslContextRef> {
Ref::map(self.ctx.borrow(), |ctx| unsafe {
&**(ctx as *const SslContextBuilder as *const ssl::SslContext)
})
}
fn ptr(&self) -> *mut sys::SSL_CTX {
self.ctx.borrow().as_ptr()
}
@@ -374,8 +424,23 @@ impl PySslContext {
);
}
if let Some(_cadata) = args.cadata {
todo!()
if let Some(cadata) = args.cadata {
let cert = match cadata {
Either::A(s) => {
if !s.as_str().is_ascii() {
return Err(vm.new_type_error("Must be an ascii string".to_owned()));
}
X509::from_pem(s.as_str().as_bytes())
}
Either::B(b) => b.with_ref(X509::from_der),
};
let cert = cert.map_err(|e| convert_openssl_error(vm, e))?;
let ctx = self.ctx();
let store = ctx.cert_store();
let ret = unsafe { sys::X509_STORE_add_cert(store.as_ptr(), cert.as_ptr()) };
if ret <= 0 {
return Err(convert_openssl_error(vm, ErrorStack::get()));
}
}
if args.cafile.is_some() || args.capath.is_some() {
@@ -395,7 +460,7 @@ impl PySslContext {
let err = if errno != 0 {
super::os::errno_err(vm)
} else {
convert_openssl_error(vm, openssl::error::ErrorStack::get())
convert_openssl_error(vm, ErrorStack::get())
};
return Err(err);
}
@@ -404,6 +469,32 @@ impl PySslContext {
Ok(())
}
#[pymethod]
fn get_ca_certs(&self, binary_form: OptionalArg<bool>, vm: &VirtualMachine) -> PyResult {
use openssl::stack::StackRef;
let binary_form = binary_form.unwrap_or(false);
let certs = unsafe {
let stack = sys::X509_STORE_get0_objects(self.ctx().cert_store().as_ptr());
assert!(!stack.is_null());
StackRef::<X509Object>::from_ptr(stack)
};
let certs = certs
.iter()
.filter_map(|cert| {
let cert = cert.x509()?;
let obj = if binary_form {
cert.to_der()
.map(|b| vm.ctx.new_bytes(b))
.map_err(|e| convert_openssl_error(vm, e))
} else {
todo!()
};
Some(obj)
})
.collect::<Result<Vec<_>, _>>()?;
Ok(vm.ctx.new_list(certs))
}
#[pymethod]
fn _wrap_socket(
zelf: PyRef<Self>,
@@ -475,7 +566,7 @@ struct LoadVerifyLocationsArgs {
#[pyarg(positional_or_keyword, default = "None")]
capath: Option<CString>,
#[pyarg(positional_or_keyword, default = "None")]
cadata: Option<PyStringRef>,
cadata: Option<Either<PyStringRef, PyBytesLike>>,
}
#[pyclass(name = "_SSLSocket")]
@@ -591,19 +682,18 @@ fn ssl_error(vm: &VirtualMachine) -> PyClassRef {
vm.class("_ssl", "SSLError")
}
fn convert_openssl_error(
vm: &VirtualMachine,
err: openssl::error::ErrorStack,
) -> PyBaseExceptionRef {
fn convert_openssl_error(vm: &VirtualMachine, err: ErrorStack) -> PyBaseExceptionRef {
let cls = ssl_error(vm);
match err.errors().first() {
Some(e) => {
let no = "unknown";
let msg = format!(
"openssl error code {}, from library {}, in function {}, on line {}, with reason {}, and extra data {}",
e.code(), e.library().unwrap_or(no), e.function().unwrap_or(no), e.line(),
e.reason().unwrap_or(no), e.data().unwrap_or("none"),
);
// let no = "unknown";
// let msg = format!(
// "openssl error code {}, from library {}, in function {}, on line {}, with reason {}, and extra data {}",
// e.code(), e.library().unwrap_or(no), e.function().unwrap_or(no), e.line(),
// e.reason().unwrap_or(no), e.data().unwrap_or("none"),
// );
// TODO: map the error codes to code names, e.g. "CERTIFICATE_VERIFY_FAILED", just requires a big hashmap/dict
let msg = e.to_string();
vm.new_exception_msg(cls, msg)
}
None => vm.new_exception_empty(cls),