diff --git a/vm/src/stdlib/os.rs b/vm/src/stdlib/os.rs index 0c1f240d79..e890929657 100644 --- a/vm/src/stdlib/os.rs +++ b/vm/src/stdlib/os.rs @@ -26,7 +26,8 @@ use crate::obj::objset::PySet; use crate::obj::objstr::{self, PyString, PyStringRef}; use crate::obj::objtype::{self, PyClassRef}; use crate::pyobject::{ - ItemProtocol, PyClassImpl, PyObjectRef, PyRef, PyResult, PyValue, TryIntoRef, TypeProtocol, + Either, ItemProtocol, PyClassImpl, PyObjectRef, PyRef, PyResult, PyValue, TryIntoRef, + TypeProtocol, }; use crate::vm::VirtualMachine; @@ -416,12 +417,46 @@ fn os_listdir(path: PyStringRef, vm: &VirtualMachine) -> PyResult { } } -fn os_putenv(key: PyStringRef, value: PyStringRef, _vm: &VirtualMachine) { - env::set_var(&key.value, &value.value) +fn bytes_as_osstr<'a>(b: &'a [u8], vm: &VirtualMachine) -> PyResult<&'a ffi::OsStr> { + let os_str = { + #[cfg(unix)] + { + use std::os::unix::ffi::OsStrExt; + Some(ffi::OsStr::from_bytes(b)) + } + #[cfg(windows)] + { + std::str::from_utf8(b).ok().map(|s| s.as_ref()) + } + }; + os_str + .ok_or_else(|| vm.new_value_error("Can't convert bytes to str for env function".to_owned())) } -fn os_unsetenv(key: PyStringRef, _vm: &VirtualMachine) { - env::remove_var(&key.value) +fn os_putenv( + key: Either, + value: Either, + vm: &VirtualMachine, +) -> PyResult<()> { + let key: &ffi::OsStr = match key { + Either::A(ref s) => s.as_str().as_ref(), + Either::B(ref b) => bytes_as_osstr(b.get_value(), vm)?, + }; + let value: &ffi::OsStr = match value { + Either::A(ref s) => s.as_str().as_ref(), + Either::B(ref b) => bytes_as_osstr(b.get_value(), vm)?, + }; + env::set_var(key, value); + Ok(()) +} + +fn os_unsetenv(key: Either, vm: &VirtualMachine) -> PyResult<()> { + let key: &ffi::OsStr = match key { + Either::A(ref s) => s.as_str().as_ref(), + Either::B(ref b) => bytes_as_osstr(b.get_value(), vm)?, + }; + env::remove_var(key); + Ok(()) } fn _os_environ(vm: &VirtualMachine) -> PyDictRef {