diff --git a/vm/src/builtins/pystr.rs b/vm/src/builtins/pystr.rs index 0f1aa31cb..9171ad0ea 100644 --- a/vm/src/builtins/pystr.rs +++ b/vm/src/builtins/pystr.rs @@ -1,8 +1,7 @@ -use std::char; -use std::fmt; use std::mem::size_of; use std::ops::Range; use std::string::ToString; +use std::{char, ffi, fmt}; use itertools::Itertools; use num_traits::ToPrimitive; @@ -25,7 +24,7 @@ use crate::utils::Either; use crate::VirtualMachine; use crate::{ IdProtocol, IntoPyObject, ItemProtocol, PyClassImpl, PyComparisonValue, PyContext, PyIterable, - PyObjectRef, PyRef, PyResult, PyValue, TryFromObject, TryIntoRef, TypeProtocol, + PyObjectRef, PyRef, PyResult, PyValue, TryIntoRef, TypeProtocol, }; use rustpython_common::atomic::{self, PyAtomic, Radium}; use rustpython_common::hash; @@ -319,6 +318,10 @@ impl PyStr { self.char_len() == self.byte_len() } + pub fn to_cstring(&self, vm: &VirtualMachine) -> PyResult { + ffi::CString::new(self.as_str()).map_err(|err| err.into_pyexception(vm)) + } + #[pymethod(name = "__sizeof__")] fn sizeof(&self) -> usize { size_of::() + self.as_str().len() * size_of::() @@ -1168,23 +1171,6 @@ impl IntoPyObject for &String { } } -impl TryFromObject for std::ffi::CString { - fn try_from_object(vm: &VirtualMachine, obj: PyObjectRef) -> PyResult { - let s = PyStrRef::try_from_object(vm, obj)?; - Self::new(s.as_str().to_owned()) - .map_err(|_| vm.new_value_error("embedded null character".to_owned())) - } -} - -impl TryFromObject for std::ffi::OsString { - fn try_from_object(vm: &VirtualMachine, obj: PyObjectRef) -> PyResult { - use std::str::FromStr; - - let s = PyStrRef::try_from_object(vm, obj)?; - Ok(std::ffi::OsString::from_str(s.as_str()).unwrap()) - } -} - type SplitArgs<'a> = anystr::SplitArgs<'a, PyStrRef>; #[derive(FromArgs)] diff --git a/vm/src/exceptions.rs b/vm/src/exceptions.rs index e1d9465ea..e1b27ac15 100644 --- a/vm/src/exceptions.rs +++ b/vm/src/exceptions.rs @@ -832,3 +832,20 @@ impl serde::Serialize for SerializeException<'_> { struc.end() } } + +pub(crate) fn cstring_error(vm: &VirtualMachine) -> PyBaseExceptionRef { + vm.new_value_error("embedded null character".to_owned()) +} + +impl IntoPyException for std::ffi::NulError { + fn into_pyexception(self, vm: &VirtualMachine) -> PyBaseExceptionRef { + cstring_error(vm) + } +} + +#[cfg(windows)] +impl IntoPyException for widestring::NulError { + fn into_pyexception(self, vm: &VirtualMachine) -> PyBaseExceptionRef { + cstring_error(vm) + } +} diff --git a/vm/src/stdlib/os.rs b/vm/src/stdlib/os.rs index a5d0f84d6..ef3acf607 100644 --- a/vm/src/stdlib/os.rs +++ b/vm/src/stdlib/os.rs @@ -99,14 +99,12 @@ impl PyPathLike { #[cfg(any(unix, target_os = "wasi"))] pub fn into_cstring(self, vm: &VirtualMachine) -> PyResult { - ffi::CString::new(self.into_bytes()) - .map_err(|_| vm.new_value_error("embedded null character".to_owned())) + ffi::CString::new(self.into_bytes()).map_err(|err| err.into_pyexception(vm)) } #[cfg(windows)] pub fn to_widecstring(&self, vm: &VirtualMachine) -> PyResult { - widestring::WideCString::from_os_str(&self.path) - .map_err(|_| vm.new_value_error("embedded null character".to_owned())) + widestring::WideCString::from_os_str(&self.path).map_err(|err| err.into_pyexception(vm)) } } @@ -124,38 +122,73 @@ impl AsRef for PyPathLike { } } -fn fspath(obj: PyObjectRef, check_for_nul: bool, vm: &VirtualMachine) -> PyResult { +pub enum FsPath { + Str(PyStrRef), + Bytes(PyBytesRef), +} + +impl FsPath { + pub fn as_os_str(&self, vm: &VirtualMachine) -> PyResult<&ffi::OsStr> { + // TODO: FS encodings + match self { + FsPath::Str(s) => Ok(s.as_str().as_ref()), + FsPath::Bytes(b) => bytes_as_osstr(b.as_bytes(), vm), + } + } + fn to_output_mode(&self) -> OutputMode { + match self { + Self::Str(_) => OutputMode::String, + Self::Bytes(_) => OutputMode::Bytes, + } + } + pub(crate) fn as_bytes(&self) -> &[u8] { + // TODO: FS encodings + match self { + FsPath::Str(s) => s.as_str().as_bytes(), + FsPath::Bytes(b) => b.as_bytes(), + } + } +} + +impl IntoPyObject for FsPath { + fn into_pyobject(self, _vm: &VirtualMachine) -> PyObjectRef { + match self { + Self::Str(s) => s.into_object(), + Self::Bytes(b) => b.into_object(), + } + } +} + +pub(crate) fn fspath( + obj: PyObjectRef, + check_for_nul: bool, + vm: &VirtualMachine, +) -> PyResult { let check_nul = |b: &[u8]| { if !check_for_nul || memchr::memchr(b'\0', b).is_none() { Ok(()) } else { - Err(vm.new_value_error("embedded null character".to_owned())) + Err(crate::exceptions::cstring_error(vm)) } }; - let match1 = |obj: &PyObjectRef| { + let match1 = |obj: PyObjectRef| { let pathlike = match_class!(match obj { - ref s @ PyStr => { - let s = s.as_str(); - check_nul(s.as_bytes())?; - PyPathLike { - path: s.into(), - mode: OutputMode::String, - } + s @ PyStr => { + check_nul(s.as_str().as_bytes())?; + FsPath::Str(s) } - ref b @ PyBytes => { + b @ PyBytes => { check_nul(&b)?; - PyPathLike { - path: bytes_as_osstr(&b, vm)?.to_os_string().into(), - mode: OutputMode::Bytes, - } + FsPath::Bytes(b) } - _ => return Ok(None), + obj => return Ok(Err(obj)), }); - Ok(Some(pathlike)) + Ok(Ok(pathlike)) + }; + let obj = match match1(obj)? { + Ok(pathlike) => return Ok(pathlike), + Err(obj) => obj, }; - if let Some(pathlike) = match1(&obj)? { - return Ok(pathlike); - } let method = vm.get_method_or_type_error(obj.clone(), "__fspath__", || { format!( "expected str, bytes or os.PathLike object, not '{}'", @@ -163,7 +196,7 @@ fn fspath(obj: PyObjectRef, check_for_nul: bool, vm: &VirtualMachine) -> PyResul ) })?; let result = vm.invoke(&method, ())?; - match1(&result)?.ok_or_else(|| { + match1(result)?.map_err(|result| { vm.new_type_error(format!( "expected {}.__fspath__() to return str or bytes, not '{}'", obj.class().name, @@ -174,7 +207,11 @@ fn fspath(obj: PyObjectRef, check_for_nul: bool, vm: &VirtualMachine) -> PyResul impl TryFromObject for PyPathLike { fn try_from_object(vm: &VirtualMachine, obj: PyObjectRef) -> PyResult { - fspath(obj, true, vm) + let path = fspath(obj, true, vm)?; + Ok(Self { + path: path.as_os_str(vm)?.to_owned().into(), + mode: path.to_output_mode(), + }) } } @@ -848,7 +885,7 @@ mod _os { FollowSymlinks(false), ) .map_err(|e| e.into_pyexception(vm))? - .ok_or_else(|| vm.new_value_error("embedded null character".to_owned()))?; + .ok_or_else(|| crate::exceptions::cstring_error(vm))?; // Err(T) means other thread set `ino` at the mean time which is safe to ignore let _ = self.ino.compare_exchange(None, Some(stat.st_ino)); Ok(stat.st_ino) @@ -1167,7 +1204,7 @@ mod _os { ) -> PyResult { let stat = stat_inner(file, dir_fd, follow_symlinks) .map_err(|e| e.into_pyexception(vm))? - .ok_or_else(|| vm.new_value_error("embedded null character".to_owned()))?; + .ok_or_else(|| crate::exceptions::cstring_error(vm))?; Ok(StatResult::from_stat(&stat).into_pyobject(vm)) } @@ -1208,9 +1245,8 @@ mod _os { } #[pyfunction] - fn fspath(path: PyObjectRef, vm: &VirtualMachine) -> PyResult { - let path = super::fspath(path, false, vm)?; - path.mode.process_path(path.path, vm) + fn fspath(path: PyObjectRef, vm: &VirtualMachine) -> PyResult { + super::fspath(path, false, vm) } #[pyfunction] @@ -2155,11 +2191,8 @@ mod posix { } #[pyfunction] - fn system(command: PyStrRef) -> PyResult { - use std::ffi::CString; - - let rstr = command.as_str(); - let cstr = CString::new(rstr).unwrap(); + fn system(command: PyStrRef, vm: &VirtualMachine) -> PyResult { + let cstr = command.to_cstring(vm)?; let x = unsafe { libc::system(cstr.as_ptr()) }; Ok(x) } @@ -2189,10 +2222,11 @@ mod posix { argv: Either, vm: &VirtualMachine, ) -> PyResult<()> { - let path = ffi::CString::new(path.as_str()) - .map_err(|_| vm.new_value_error("embedded null character".to_owned()))?; + let path = path.to_cstring(vm)?; - let argv: Vec = vm.extract_elements(argv.as_object())?; + let argv = vm.extract_elements_func(argv.as_object(), |obj| { + PyStrRef::try_from_object(vm, obj)?.to_cstring(vm) + })?; let argv: Vec<&ffi::CStr> = argv.iter().map(|entry| entry.as_c_str()).collect(); let first = argv @@ -2216,10 +2250,11 @@ mod posix { env: PyDictRef, vm: &VirtualMachine, ) -> PyResult<()> { - let path = ffi::CString::new(path.into_bytes()) - .map_err(|_| vm.new_value_error("embedded null character".to_owned()))?; + let path = path.into_cstring(vm)?; - let argv: Vec = vm.extract_elements(argv.as_object())?; + let argv = vm.extract_elements_func(argv.as_object(), |obj| { + PyStrRef::try_from_object(vm, obj)?.to_cstring(vm) + })?; let argv: Vec<&ffi::CStr> = argv.iter().map(|entry| entry.as_c_str()).collect(); let first = argv @@ -2236,16 +2271,19 @@ mod posix { .into_iter() .map(|(k, v)| -> PyResult<_> { let (key, value) = ( - PyPathLike::try_from_object(&vm, k)?, - PyPathLike::try_from_object(&vm, v)?, + PyPathLike::try_from_object(&vm, k)?.into_bytes(), + PyPathLike::try_from_object(&vm, v)?.into_bytes(), ); - if key.path.display().to_string().contains('=') { + if memchr::memchr(b'=', &key).is_some() { return Err(vm.new_value_error("illegal environment variable name".to_owned())); } - ffi::CString::new(format!("{}={}", key.path.display(), value.path.display())) - .map_err(|_| vm.new_value_error("embedded null character".to_owned())) + let mut entry = key; + entry.push(b'='); + entry.extend_from_slice(&value); + + ffi::CString::new(entry).map_err(|err| err.into_pyexception(vm)) }) .collect::, _>>()?; @@ -3164,16 +3202,17 @@ mod nt { vm: &VirtualMachine, ) -> PyResult<()> { use std::iter::once; - use std::os::windows::prelude::*; - use std::str::FromStr; - let path: Vec = ffi::OsString::from_str(path.as_str()) - .unwrap() - .encode_wide() - .chain(once(0u16)) - .collect(); + let make_widestring = |s: &str| { + widestring::WideCString::from_os_str(s).map_err(|err| err.into_pyexception(vm)) + }; - let argv: Vec = vm.extract_elements(argv.as_object())?; + let path = make_widestring(path.as_str())?; + + let argv = vm.extract_elements_func(argv.as_object(), |obj| { + let arg = PyStrRef::try_from_object(vm, obj)?; + make_widestring(arg.as_str()) + })?; let first = argv .first() @@ -3185,11 +3224,6 @@ mod nt { ); } - let argv: Vec> = argv - .into_iter() - .map(|s| s.encode_wide().chain(once(0u16)).collect()) - .collect(); - let argv_execv: Vec<*const u16> = argv .iter() .map(|v| v.as_ptr()) diff --git a/vm/src/stdlib/posixsubprocess.rs b/vm/src/stdlib/posixsubprocess.rs index fdae86970..2707d709f 100644 --- a/vm/src/stdlib/posixsubprocess.rs +++ b/vm/src/stdlib/posixsubprocess.rs @@ -54,9 +54,7 @@ struct CStrPathLike { } impl TryFromObject for CStrPathLike { fn try_from_object(vm: &VirtualMachine, obj: PyObjectRef) -> PyResult { - let s = os::PyPathLike::try_from_object(vm, obj)?.into_bytes(); - let s = CString::new(s) - .map_err(|_| vm.new_value_error("embedded null character".to_owned()))?; + let s = os::PyPathLike::try_from_object(vm, obj)?.into_cstring(vm)?; Ok(CStrPathLike { s }) } } diff --git a/vm/src/stdlib/ssl.rs b/vm/src/stdlib/ssl.rs index 03b81aae9..391502c01 100644 --- a/vm/src/stdlib/ssl.rs +++ b/vm/src/stdlib/ssl.rs @@ -22,7 +22,7 @@ use openssl::{ x509::{self, X509Object, X509Ref, X509}, }; use std::convert::TryFrom; -use std::ffi::{CStr, CString}; +use std::ffi::CStr; use std::fmt; use std::time::Instant; @@ -177,17 +177,15 @@ fn _ssl_enum_certificates(store_name: PyStrRef, vm: &VirtualMachine) -> PyResult #[derive(FromArgs)] struct Txt2ObjArgs { #[pyarg(any)] - txt: CString, + txt: PyStrRef, #[pyarg(any, default = "false")] name: bool, } fn _ssl_txt2obj(args: Txt2ObjArgs, vm: &VirtualMachine) -> PyResult { - txt2obj(&args.txt, !args.name) + txt2obj(&args.txt.to_cstring(vm)?, !args.name) .as_deref() .map(obj2py) - .ok_or_else(|| { - vm.new_value_error(format!("unknown object '{}'", args.txt.to_str().unwrap())) - }) + .ok_or_else(|| vm.new_value_error(format!("unknown object '{}'", args.txt))) } fn _ssl_nid2obj(nid: libc::c_int, vm: &VirtualMachine) -> PyResult { @@ -344,7 +342,7 @@ impl PySslContext { fn set_ciphers(&self, cipherlist: PyStrRef, vm: &VirtualMachine) -> PyResult<()> { let ciphers = cipherlist.as_str(); if ciphers.contains('\0') { - return Err(vm.new_value_error("embedded null character".to_owned())); + return Err(crate::exceptions::cstring_error(vm)); } self.builder().set_cipher_list(ciphers).map_err(|_| { vm.new_exception_msg(ssl_error(vm), "No cipher can be selected.".to_owned()) @@ -478,14 +476,16 @@ impl PySslContext { } if args.cafile.is_some() || args.capath.is_some() { + let cafile = args.cafile.map(|s| s.to_cstring(vm)).transpose()?; + let capath = args.capath.map(|s| s.to_cstring(vm)).transpose()?; let ret = unsafe { let ctx = self.ctx.write(); sys::SSL_CTX_load_verify_locations( ctx.as_ptr(), - args.cafile + cafile .as_ref() .map_or_else(std::ptr::null, |cs| cs.as_ptr()), - args.capath + capath .as_ref() .map_or_else(std::ptr::null, |cs| cs.as_ptr()), ) @@ -624,9 +624,9 @@ struct WrapSocketArgs { #[derive(FromArgs)] struct LoadVerifyLocationsArgs { #[pyarg(any, default)] - cafile: Option, + cafile: Option, #[pyarg(any, default)] - capath: Option, + capath: Option, #[pyarg(any, default)] cadata: Option>, } diff --git a/vm/src/stdlib/winapi.rs b/vm/src/stdlib/winapi.rs index 0439a25c3..13ee546ac 100644 --- a/vm/src/stdlib/winapi.rs +++ b/vm/src/stdlib/winapi.rs @@ -12,6 +12,7 @@ use winapi::um::{ use super::os::errno_err; use crate::builtins::dict::{PyDictRef, PyMapping}; use crate::builtins::pystr::PyStrRef; +use crate::exceptions::IntoPyException; use crate::function::OptionalArg; use crate::VirtualMachine; use crate::{PyObjectRef, PyResult, PySequence, TryFromObject}; @@ -170,14 +171,9 @@ fn _winapi_CreateProcess( .map_or_else(null_mut, |l| l.attrlist.as_mut_ptr() as _); let wstr = |s: PyStrRef| { - if s.as_str().contains('\0') { - Err(vm.new_value_error("embedded null character".to_owned())) - } else { - Ok(s.as_str() - .encode_utf16() - .chain(std::iter::once(0)) - .collect::>()) - } + let ws = widestring::WideCString::from_str(s.as_str()) + .map_err(|err| err.into_pyexception(vm))?; + Ok(ws.into_vec_with_nul()) }; let app_name = args.name.map(wstr).transpose()?; @@ -224,25 +220,25 @@ fn _winapi_CreateProcess( } fn getenvironment(env: PyDictRef, vm: &VirtualMachine) -> PyResult> { - let mut out = vec![]; + let mut out = widestring::WideString::new(); for (k, v) in env { let k = PyStrRef::try_from_object(vm, k)?; let k = k.as_str(); let v = PyStrRef::try_from_object(vm, v)?; let v = v.as_str(); if k.contains('\0') || v.contains('\0') { - return Err(vm.new_value_error("embedded null character".to_owned())); + return Err(crate::exceptions::cstring_error(vm)); } - if k.len() == 0 || k[1..].contains('=') { + if k.is_empty() || k[1..].contains('=') { return Err(vm.new_value_error("illegal environment variable name".to_owned())); } - out.extend(k.encode_utf16()); - out.push(b'=' as u16); - out.extend(v.encode_utf16()); - out.push(b'\0' as u16); + out.push_str(k); + out.push_str("="); + out.push_str(v); + out.push_str("\0"); } - out.push(b'\0' as u16); - Ok(out) + out.push_str("\0"); + Ok(out.into_vec()) } struct AttrList {