diff --git a/vm/Cargo.toml b/vm/Cargo.toml index 3d824b2ae..4f10db8be 100644 --- a/vm/Cargo.toml +++ b/vm/Cargo.toml @@ -113,9 +113,9 @@ schannel = "0.1.19" widestring = "0.5.1" [target.'cfg(windows)'.dependencies.windows] -version = "0.39" +version = "0.39.0" features = [ - "Win32_UI_Shell", + "Win32_UI_Shell", "Win32_System_LibraryLoader", "Win32_Foundation" ] [target.'cfg(windows)'.dependencies.winapi] diff --git a/vm/src/stdlib/winapi.rs b/vm/src/stdlib/winapi.rs index c28faf43b..1a38a2867 100644 --- a/vm/src/stdlib/winapi.rs +++ b/vm/src/stdlib/winapi.rs @@ -10,12 +10,19 @@ mod _winapi { stdlib::os::errno_err, PyObjectRef, PyResult, TryFromObject, VirtualMachine, }; + use std::ffi::{OsStr, OsString}; + use std::os::windows::prelude::*; use std::ptr::{null, null_mut}; use winapi::shared::winerror; use winapi::um::{ fileapi, handleapi, namedpipeapi, processenv, processthreadsapi, synchapi, winbase, winnt::HANDLE, }; + use windows::{ + core::PCWSTR, + Win32::Foundation::{HINSTANCE, MAX_PATH}, + Win32::System::LibraryLoader::{GetModuleFileNameW, LoadLibraryW}, + }; #[pyattr] use winapi::{ @@ -402,4 +409,56 @@ mod _winapi { }) .map(drop) } + + pub trait ToWideString { + fn to_wide(&self) -> Vec; + fn to_wides_with_nul(&self) -> Vec; + } + impl ToWide for T + where + T: AsRef, + { + fn to_wide(&self) -> Vec { + self.as_ref().encode_wide().collect() + } + fn to_wide_null(&self) -> Vec { + self.as_ref().encode_wide().chain(Some(0)).collect() + } + } + pub trait FromWide + where + Self: Sized, + { + fn from_wides_until_nul(wide: &[u16]) -> Self; + } + impl FromWide for OsString { + fn from_wide_null(wide: &[u16]) -> OsString { + let len = wide.iter().take_while(|&&c| c != 0).count(); + OsString::from_wide(&wide[..len]) + } + } + + #[pyfunction] + fn LoadLibrary(path: PyStrRef, vm: &VirtualMachine) -> PyResult { + let path = path.as_str().to_wide_null(); + let handle = unsafe { LoadLibraryW(PCWSTR::from_raw(path.as_ptr())).unwrap() }; + if handle.is_invalid() { + return Err(vm.new_runtime_error("LoadLibrary failed".to_owned())); + } + Ok(handle.0) + } + + #[pyfunction] + fn GetModuleFileName(handle: isize, vm: &VirtualMachine) -> PyResult { + let mut path: Vec = vec![0; MAX_PATH as usize]; + let handle = HINSTANCE(handle); + + let length = unsafe { GetModuleFileNameW(handle, &mut path) }; + if length == 0 { + return Err(vm.new_runtime_error("GetModuleFileName failed".to_owned())); + } + + let (path, _) = path.split_at(length as usize); + Ok(String::from_utf16(&path).unwrap()) + } }