diff --git a/Lib/test/test_socket.py b/Lib/test/test_socket.py index 5b8285310..1b8d968e4 100644 --- a/Lib/test/test_socket.py +++ b/Lib/test/test_socket.py @@ -1578,7 +1578,6 @@ class GeneralModuleTests(unittest.TestCase): # only IP addresses are allowed self.assertRaises(OSError, socket.getnameinfo, ('mail.python.org',0), 0) - @unittest.expectedFailureIf(sys.platform != "darwin", "TODO: RUSTPYTHON; socket.gethostbyname_ex") @unittest.skipUnless(support.is_resource_enabled('network'), 'network is not enabled') def test_idna(self): @@ -5519,8 +5518,6 @@ class TestUnixDomain(unittest.TestCase): self.addCleanup(os_helper.unlink, path) self.assertEqual(self.sock.getsockname(), path) - # TODO: RUSTPYTHON, surrogateescape - @unittest.expectedFailure def testSurrogateescapeBind(self): # Test binding to a valid non-ASCII pathname, with the # non-ASCII bytes supplied using surrogateescape encoding. diff --git a/stdlib/src/socket.rs b/stdlib/src/socket.rs index 39bfde4be..17daec775 100644 --- a/stdlib/src/socket.rs +++ b/stdlib/src/socket.rs @@ -930,10 +930,15 @@ mod _socket { match family { #[cfg(unix)] c::AF_UNIX => { + use crate::vm::function::ArgStrOrBytesLike; use std::os::unix::ffi::OsStrExt; - let buf = crate::vm::function::ArgStrOrBytesLike::try_from_object(vm, addr)?; - let path = &*buf.borrow_bytes(); - socket2::SockAddr::unix(ffi::OsStr::from_bytes(path)) + let buf = ArgStrOrBytesLike::try_from_object(vm, addr)?; + let bytes = &*buf.borrow_bytes(); + let path = match &buf { + ArgStrOrBytesLike::Buf(_) => ffi::OsStr::from_bytes(bytes).into(), + ArgStrOrBytesLike::Str(s) => vm.fsencode(s)?, + }; + socket2::SockAddr::unix(path) .map_err(|_| vm.new_os_error("AF_UNIX path too long".to_owned()).into()) } c::AF_INET => { @@ -1704,7 +1709,7 @@ mod _socket { let path = ffi::OsStr::as_bytes(addr.as_pathname().unwrap_or("".as_ref()).as_ref()); let nul_pos = memchr::memchr(b'\0', path).unwrap_or(path.len()); let path = ffi::OsStr::from_bytes(&path[..nul_pos]); - return vm.ctx.new_str(path.to_string_lossy()).into(); + return vm.fsdecode(path).into(); } // TODO: support more address families (String::new(), 0).to_pyobject(vm) diff --git a/vm/src/function/fspath.rs b/vm/src/function/fspath.rs index ab5cd093b..e034487e1 100644 --- a/vm/src/function/fspath.rs +++ b/vm/src/function/fspath.rs @@ -5,7 +5,7 @@ use crate::{ function::PyStr, protocol::PyBuffer, }; -use std::{ffi::OsStr, path::PathBuf}; +use std::{borrow::Cow, ffi::OsStr, path::PathBuf}; #[derive(Clone)] pub enum FsPath { @@ -58,15 +58,11 @@ impl FsPath { }) } - pub fn as_os_str(&self, vm: &VirtualMachine) -> PyResult<&OsStr> { + pub fn as_os_str(&self, vm: &VirtualMachine) -> PyResult> { // TODO: FS encodings match self { - FsPath::Str(s) => { - // XXX RUSTPYTHON: this is sketchy on windows; it's not guaranteed that its - // OsStr encoding will always be compatible with WTF-8. - Ok(unsafe { OsStr::from_encoded_bytes_unchecked(s.as_wtf8().as_bytes()) }) - } - FsPath::Bytes(b) => Self::bytes_as_osstr(b.as_bytes(), vm), + FsPath::Str(s) => vm.fsencode(s), + FsPath::Bytes(b) => Self::bytes_as_osstr(b.as_bytes(), vm).map(Cow::Borrowed), } } diff --git a/vm/src/ospath.rs b/vm/src/ospath.rs index 9dda60d62..c1b185916 100644 --- a/vm/src/ospath.rs +++ b/vm/src/ospath.rs @@ -21,28 +21,14 @@ pub(super) enum OutputMode { } impl OutputMode { - pub(super) fn process_path(self, path: impl Into, vm: &VirtualMachine) -> PyResult { - fn inner(mode: OutputMode, path: PathBuf, vm: &VirtualMachine) -> PyResult { - let path_as_string = |p: PathBuf| { - p.into_os_string().into_string().map_err(|_| { - vm.new_unicode_decode_error( - "Can't convert OS path to valid UTF-8 string".into(), - ) - }) - }; + pub(super) fn process_path(self, path: impl Into, vm: &VirtualMachine) -> PyObjectRef { + fn inner(mode: OutputMode, path: PathBuf, vm: &VirtualMachine) -> PyObjectRef { match mode { - OutputMode::String => path_as_string(path).map(|s| vm.ctx.new_str(s).into()), - OutputMode::Bytes => { - #[cfg(any(unix, target_os = "wasi"))] - { - use rustpython_common::os::ffi::OsStringExt; - Ok(vm.ctx.new_bytes(path.into_os_string().into_vec()).into()) - } - #[cfg(windows)] - { - path_as_string(path).map(|s| vm.ctx.new_bytes(s.into_bytes()).into()) - } - } + OutputMode::String => vm.fsdecode(path).into(), + OutputMode::Bytes => vm + .ctx + .new_bytes(path.into_os_string().into_encoded_bytes()) + .into(), } } inner(self, path.into(), vm) @@ -59,7 +45,7 @@ impl OsPath { } pub(crate) fn from_fspath(fspath: FsPath, vm: &VirtualMachine) -> PyResult { - let path = fspath.as_os_str(vm)?.to_owned(); + let path = fspath.as_os_str(vm)?.into_owned(); let mode = match fspath { FsPath::Str(_) => OutputMode::String, FsPath::Bytes(_) => OutputMode::Bytes, @@ -88,7 +74,7 @@ impl OsPath { widestring::WideCString::from_os_str(&self.path).map_err(|err| err.to_pyexception(vm)) } - pub fn filename(&self, vm: &VirtualMachine) -> PyResult { + pub fn filename(&self, vm: &VirtualMachine) -> PyObjectRef { self.mode.process_path(self.path.clone(), vm) } } @@ -133,7 +119,7 @@ impl From for OsPathOrFd { impl OsPathOrFd { pub fn filename(&self, vm: &VirtualMachine) -> PyObjectRef { match self { - OsPathOrFd::Path(path) => path.filename(vm).unwrap_or_else(|_| vm.ctx.none()), + OsPathOrFd::Path(path) => path.filename(vm), OsPathOrFd::Fd(fd) => vm.ctx.new_int(*fd).into(), } } diff --git a/vm/src/stdlib/codecs.rs b/vm/src/stdlib/codecs.rs index 9018ad0e3..664fe0061 100644 --- a/vm/src/stdlib/codecs.rs +++ b/vm/src/stdlib/codecs.rs @@ -312,7 +312,12 @@ mod _codecs { #[pyfunction] fn utf_8_encode(args: EncodeArgs, vm: &VirtualMachine) -> EncodeResult { - if args.s.is_utf8() { + if args.s.is_utf8() + || args + .errors + .as_ref() + .is_some_and(|s| s.is(identifier!(vm, surrogatepass))) + { return Ok((args.s.as_bytes().to_vec(), args.s.byte_len())); } do_codec!(utf8::encode, args, vm) diff --git a/vm/src/stdlib/io.rs b/vm/src/stdlib/io.rs index 30d09b78a..0b680251f 100644 --- a/vm/src/stdlib/io.rs +++ b/vm/src/stdlib/io.rs @@ -2225,7 +2225,7 @@ mod _io { *data = None; let encoding = match args.encoding { - None if vm.state.settings.utf8_mode > 0 => PyStr::from("utf-8").into_ref(&vm.ctx), + None if vm.state.settings.utf8_mode > 0 => identifier!(vm, utf_8).to_owned(), Some(enc) if enc.as_wtf8() != "locale" => enc, _ => { // None without utf8_mode or "locale" encoding @@ -2238,7 +2238,7 @@ mod _io { let errors = args .errors - .unwrap_or_else(|| PyStr::from("strict").into_ref(&vm.ctx)); + .unwrap_or_else(|| identifier!(vm, strict).to_owned()); let has_read1 = vm.get_attribute_opt(buffer.clone(), "read1")?.is_some(); let seekable = vm.call_method(&buffer, "seekable", ())?.try_to_bool(vm)?; diff --git a/vm/src/stdlib/nt.rs b/vm/src/stdlib/nt.rs index 48f1ab668..624577b5c 100644 --- a/vm/src/stdlib/nt.rs +++ b/vm/src/stdlib/nt.rs @@ -249,7 +249,7 @@ pub(crate) mod module { .as_ref() .canonicalize() .map_err(|e| e.to_pyexception(vm))?; - path.mode.process_path(real, vm) + Ok(path.mode.process_path(real, vm)) } #[pyfunction] @@ -282,7 +282,7 @@ pub(crate) mod module { } } let buffer = widestring::WideCString::from_vec_truncate(buffer); - path.mode.process_path(buffer.to_os_string(), vm) + Ok(path.mode.process_path(buffer.to_os_string(), vm)) } #[pyfunction] @@ -297,7 +297,7 @@ pub(crate) mod module { return Err(errno_err(vm)); } let buffer = widestring::WideCString::from_vec_truncate(buffer); - path.mode.process_path(buffer.to_os_string(), vm) + Ok(path.mode.process_path(buffer.to_os_string(), vm)) } #[pyfunction] diff --git a/vm/src/stdlib/os.rs b/vm/src/stdlib/os.rs index 9b196b6d0..e1a5825b8 100644 --- a/vm/src/stdlib/os.rs +++ b/vm/src/stdlib/os.rs @@ -332,7 +332,7 @@ pub(super) mod _os { }; dir_iter .map(|entry| match entry { - Ok(entry_path) => path.mode.process_path(entry_path.file_name(), vm), + Ok(entry_path) => Ok(path.mode.process_path(entry_path.file_name(), vm)), Err(err) => Err(IOErrorBuilder::with_filename(&err, path.clone(), vm)), }) .collect::>()? @@ -352,22 +352,18 @@ pub(super) mod _os { let mut dir = nix::dir::Dir::from_fd(new_fd).map_err(|e| e.into_pyexception(vm))?; dir.iter() - .filter_map(|entry| { - entry - .map_err(|e| e.into_pyexception(vm)) - .and_then(|entry| { - let fname = entry.file_name().to_bytes(); - Ok(match fname { - b"." | b".." => None, - _ => Some( - OutputMode::String - .process_path(ffi::OsStr::from_bytes(fname), vm)?, - ), - }) - }) - .transpose() + .filter_map_ok(|entry| { + let fname = entry.file_name().to_bytes(); + match fname { + b"." | b".." => None, + _ => Some( + OutputMode::String + .process_path(ffi::OsStr::from_bytes(fname), vm), + ), + } }) - .collect::>()? + .collect::>() + .map_err(|e| e.into_pyexception(vm))? } } }; @@ -429,7 +425,7 @@ pub(super) mod _os { let [] = dir_fd.0; let path = fs::read_link(&path).map_err(|err| IOErrorBuilder::with_filename(&err, path, vm))?; - mode.process_path(path, vm) + Ok(mode.process_path(path, vm)) } #[pyattr] @@ -452,12 +448,12 @@ pub(super) mod _os { impl DirEntry { #[pygetset] fn name(&self, vm: &VirtualMachine) -> PyResult { - self.mode.process_path(&self.file_name, vm) + Ok(self.mode.process_path(&self.file_name, vm)) } #[pygetset] fn path(&self, vm: &VirtualMachine) -> PyResult { - self.mode.process_path(&self.pathval, vm) + Ok(self.mode.process_path(&self.pathval, vm)) } fn perform_on_metadata( @@ -908,12 +904,12 @@ pub(super) mod _os { #[pyfunction] fn getcwd(vm: &VirtualMachine) -> PyResult { - OutputMode::String.process_path(curdir_inner(vm)?, vm) + Ok(OutputMode::String.process_path(curdir_inner(vm)?, vm)) } #[pyfunction] fn getcwdb(vm: &VirtualMachine) -> PyResult { - OutputMode::Bytes.process_path(curdir_inner(vm)?, vm) + Ok(OutputMode::Bytes.process_path(curdir_inner(vm)?, vm)) } #[pyfunction] diff --git a/vm/src/stdlib/sys.rs b/vm/src/stdlib/sys.rs index 39c803a01..fdfe2faf6 100644 --- a/vm/src/stdlib/sys.rs +++ b/vm/src/stdlib/sys.rs @@ -458,21 +458,13 @@ mod sys { } #[pyfunction] - fn getfilesystemencoding(_vm: &VirtualMachine) -> String { - // TODO: implement non-utf-8 mode. - "utf-8".to_owned() + fn getfilesystemencoding(vm: &VirtualMachine) -> PyStrRef { + vm.fs_encoding().to_owned() } - #[cfg(not(windows))] #[pyfunction] - fn getfilesystemencodeerrors(_vm: &VirtualMachine) -> String { - "surrogateescape".to_owned() - } - - #[cfg(windows)] - #[pyfunction] - fn getfilesystemencodeerrors(_vm: &VirtualMachine) -> String { - "surrogatepass".to_owned() + fn getfilesystemencodeerrors(vm: &VirtualMachine) -> PyStrRef { + vm.fs_encode_errors().to_owned() } #[pyfunction] diff --git a/vm/src/vm/context.rs b/vm/src/vm/context.rs index 54605704a..a61484e6b 100644 --- a/vm/src/vm/context.rs +++ b/vm/src/vm/context.rs @@ -51,7 +51,7 @@ pub struct Context { } macro_rules! declare_const_name { - ($($name:ident,)*) => { + ($($name:ident$(: $s:literal)?,)*) => { #[derive(Debug, Clone, Copy)] #[allow(non_snake_case)] pub struct ConstName { @@ -61,11 +61,13 @@ macro_rules! declare_const_name { impl ConstName { unsafe fn new(pool: &StringPool, typ: &PyTypeRef) -> Self { Self { - $($name: unsafe { pool.intern(stringify!($name), typ.clone()) },)* + $($name: unsafe { pool.intern(declare_const_name!(@string $name $($s)?), typ.clone()) },)* } } } - } + }; + (@string $name:ident) => { stringify!($name) }; + (@string $name:ident $string:literal) => { $string }; } declare_const_name! { @@ -236,6 +238,15 @@ declare_const_name! { flush, close, WarningMessage, + strict, + ignore, + replace, + xmlcharrefreplace, + backslashreplace, + namereplace, + surrogatepass, + surrogateescape, + utf_8: "utf-8", } // Basic objects: diff --git a/vm/src/vm/mod.rs b/vm/src/vm/mod.rs index eb4e846de..9d7ecc2d5 100644 --- a/vm/src/vm/mod.rs +++ b/vm/src/vm/mod.rs @@ -41,11 +41,12 @@ use nix::{ sys::signal::{SaFlags, SigAction, SigSet, Signal::SIGINT, kill, sigaction}, unistd::getpid, }; -use std::sync::atomic::AtomicBool; use std::{ borrow::Cow, cell::{Cell, Ref, RefCell}, collections::{HashMap, HashSet}, + ffi::{OsStr, OsString}, + sync::atomic::AtomicBool, }; pub use context::Context; @@ -901,6 +902,46 @@ impl VirtualMachine { run_module_as_main.call((module,), self)?; Ok(()) } + + pub fn fs_encoding(&self) -> &'static PyStrInterned { + identifier!(self, utf_8) + } + + pub fn fs_encode_errors(&self) -> &'static PyStrInterned { + if cfg!(windows) { + identifier!(self, surrogatepass) + } else { + identifier!(self, surrogateescape) + } + } + + pub fn fsdecode(&self, s: impl Into) -> PyStrRef { + let bytes = self.ctx.new_bytes(s.into().into_encoded_bytes()); + let errors = self.fs_encode_errors().to_owned(); + self.state + .codec_registry + .decode_text(bytes.into(), "utf-8", Some(errors), self) + .unwrap() // this should never fail, since fsdecode should be lossless from the fs encoding + } + + pub fn fsencode<'a>(&self, s: &'a Py) -> PyResult> { + if cfg!(windows) || s.is_utf8() { + // XXX: this is sketchy on windows; it's not guaranteed that the + // OsStr encoding will always be compatible with WTF-8. + let s = unsafe { OsStr::from_encoded_bytes_unchecked(s.as_bytes()) }; + return Ok(Cow::Borrowed(s)); + } + let errors = self.fs_encode_errors().to_owned(); + let bytes = self + .state + .codec_registry + .encode_text(s.to_owned(), "utf-8", Some(errors), self)? + .to_vec(); + // XXX: this is sketchy on windows; it's not guaranteed that the + // OsStr encoding will always be compatible with WTF-8. + let s = unsafe { OsString::from_encoded_bytes_unchecked(bytes) }; + Ok(Cow::Owned(s)) + } } impl AsRef for VirtualMachine {