Make Gid/Uid less janky

This commit is contained in:
Noa
2024-09-18 17:17:09 -05:00
parent b5c1fd95dc
commit 8152e7e62c

View File

@@ -1111,17 +1111,13 @@ pub mod module {
}
#[pyfunction]
fn setgid(gid: Option<Gid>, vm: &VirtualMachine) -> PyResult<()> {
let gid =
gid.ok_or_else(|| vm.new_errno_error(1, "Operation not permitted".to_string()))?;
fn setgid(gid: Gid, vm: &VirtualMachine) -> PyResult<()> {
unistd::setgid(gid).map_err(|err| err.into_pyexception(vm))
}
#[cfg(not(target_os = "redox"))]
#[pyfunction]
fn setegid(egid: Option<Gid>, vm: &VirtualMachine) -> PyResult<()> {
let egid =
egid.ok_or_else(|| vm.new_errno_error(1, "Operation not permitted".to_string()))?;
fn setegid(egid: Gid, vm: &VirtualMachine) -> PyResult<()> {
unistd::setegid(egid).map_err(|err| err.into_pyexception(vm))
}
@@ -1139,7 +1135,7 @@ pub mod module {
.map_err(|err| err.into_pyexception(vm))
}
fn try_from_id(vm: &VirtualMachine, obj: PyObjectRef, typ_name: &str) -> PyResult<Option<u32>> {
fn try_from_id(vm: &VirtualMachine, obj: PyObjectRef, typ_name: &str) -> PyResult<u32> {
use std::cmp::Ordering;
let i = obj
.try_to_ref::<PyInt>(vm)
@@ -1152,53 +1148,47 @@ pub mod module {
.try_to_primitive::<i64>(vm)?;
match i.cmp(&-1) {
Ordering::Greater => Ok(Some(i.try_into().map_err(|_| {
Ordering::Greater => Ok(i.try_into().map_err(|_| {
vm.new_overflow_error(format!("{typ_name} is larger than maximum"))
})?)),
})?),
Ordering::Less => {
Err(vm.new_overflow_error(format!("{typ_name} is less than minimum")))
}
Ordering::Equal => Ok(None), // -1 means does not change the value
// -1 means does not change the value
// In CPython, this is `(uid_t) -1`, rustc gets mad when we try to declare
// a negative unsigned integer :).
Ordering::Equal => Ok(-1i32 as u32),
}
}
impl TryFromObject for Option<Uid> {
impl TryFromObject for Uid {
fn try_from_object(vm: &VirtualMachine, obj: PyObjectRef) -> PyResult<Self> {
Ok(try_from_id(vm, obj, "uid")?.map(Uid::from_raw))
try_from_id(vm, obj, "uid").map(Uid::from_raw)
}
}
impl TryFromObject for Option<Gid> {
impl TryFromObject for Gid {
fn try_from_object(vm: &VirtualMachine, obj: PyObjectRef) -> PyResult<Self> {
Ok(try_from_id(vm, obj, "gid")?.map(Gid::from_raw))
try_from_id(vm, obj, "gid").map(Gid::from_raw)
}
}
#[pyfunction]
fn setuid(uid: Option<Uid>, vm: &VirtualMachine) -> PyResult<()> {
let uid =
uid.ok_or_else(|| vm.new_errno_error(1, "Operation not permitted".to_string()))?;
unistd::setuid(uid).map_err(|err| err.into_pyexception(vm))
fn setuid(uid: Uid) -> nix::Result<()> {
unistd::setuid(uid)
}
#[cfg(not(target_os = "redox"))]
#[pyfunction]
fn seteuid(euid: Option<Uid>, vm: &VirtualMachine) -> PyResult<()> {
let euid =
euid.ok_or_else(|| vm.new_errno_error(1, "Operation not permitted".to_string()))?;
unistd::seteuid(euid).map_err(|err| err.into_pyexception(vm))
fn seteuid(euid: Uid) -> nix::Result<()> {
unistd::seteuid(euid)
}
#[cfg(not(target_os = "redox"))]
#[pyfunction]
fn setreuid(ruid: Option<Uid>, euid: Option<Uid>, vm: &VirtualMachine) -> PyResult<()> {
if let Some(ruid) = ruid {
unistd::setuid(ruid).map_err(|err| err.into_pyexception(vm))?;
}
if let Some(euid) = euid {
unistd::seteuid(euid).map_err(|err| err.into_pyexception(vm))?;
}
Ok(())
fn setreuid(ruid: Uid, euid: Uid) -> nix::Result<()> {
let ret = unsafe { libc::setreuid(ruid.as_raw(), euid.as_raw()) };
nix::Error::result(ret).map(drop)
}
// cfg from nix
@@ -1209,20 +1199,8 @@ pub mod module {
target_os = "openbsd"
))]
#[pyfunction]
fn setresuid(
ruid: Option<Uid>,
euid: Option<Uid>,
suid: Option<Uid>,
vm: &VirtualMachine,
) -> PyResult<()> {
let unwrap_or_unchanged =
|u: Option<Uid>| u.unwrap_or_else(|| Uid::from_raw(libc::uid_t::MAX));
unistd::setresuid(
unwrap_or_unchanged(ruid),
unwrap_or_unchanged(euid),
unwrap_or_unchanged(suid),
)
.map_err(|err| err.into_pyexception(vm))
fn setresuid(ruid: Uid, euid: Uid, suid: Uid) -> nix::Result<()> {
unistd::setresuid(ruid, euid, suid)
}
#[cfg(not(target_os = "redox"))]
@@ -1271,8 +1249,8 @@ pub mod module {
// cfg from nix
#[cfg(any(target_os = "android", target_os = "linux", target_os = "openbsd"))]
#[pyfunction]
fn getresuid(vm: &VirtualMachine) -> PyResult<(u32, u32, u32)> {
let ret = unistd::getresuid().map_err(|e| e.into_pyexception(vm))?;
fn getresuid() -> nix::Result<(u32, u32, u32)> {
let ret = unistd::getresuid()?;
Ok((
ret.real.as_raw(),
ret.effective.as_raw(),
@@ -1283,8 +1261,8 @@ pub mod module {
// cfg from nix
#[cfg(any(target_os = "android", target_os = "linux", target_os = "openbsd"))]
#[pyfunction]
fn getresgid(vm: &VirtualMachine) -> PyResult<(u32, u32, u32)> {
let ret = unistd::getresgid().map_err(|e| e.into_pyexception(vm))?;
fn getresgid() -> nix::Result<(u32, u32, u32)> {
let ret = unistd::getresgid()?;
Ok((
ret.real.as_raw(),
ret.effective.as_raw(),
@@ -1300,32 +1278,15 @@ pub mod module {
target_os = "openbsd"
))]
#[pyfunction]
fn setresgid(
rgid: Option<Gid>,
egid: Option<Gid>,
sgid: Option<Gid>,
vm: &VirtualMachine,
) -> PyResult<()> {
let unwrap_or_unchanged =
|u: Option<Gid>| u.unwrap_or_else(|| Gid::from_raw(libc::gid_t::MAX));
unistd::setresgid(
unwrap_or_unchanged(rgid),
unwrap_or_unchanged(egid),
unwrap_or_unchanged(sgid),
)
.map_err(|err| err.into_pyexception(vm))
fn setresgid(rgid: Gid, egid: Gid, sgid: Gid, vm: &VirtualMachine) -> PyResult<()> {
unistd::setresgid(rgid, egid, sgid).map_err(|err| err.into_pyexception(vm))
}
#[cfg(not(target_os = "redox"))]
#[pyfunction]
fn setregid(rgid: Option<Gid>, egid: Option<Gid>, vm: &VirtualMachine) -> PyResult<()> {
if let Some(rgid) = rgid {
unistd::setgid(rgid).map_err(|err| err.into_pyexception(vm))?;
}
if let Some(egid) = egid {
unistd::setegid(egid).map_err(|err| err.into_pyexception(vm))?;
}
Ok(())
fn setregid(rgid: Gid, egid: Gid) -> nix::Result<()> {
let ret = unsafe { libc::setregid(rgid.as_raw(), egid.as_raw()) };
nix::Error::result(ret).map(drop)
}
// cfg from nix
@@ -1336,10 +1297,8 @@ pub mod module {
target_os = "openbsd"
))]
#[pyfunction]
fn initgroups(user_name: PyStrRef, gid: Option<Gid>, vm: &VirtualMachine) -> PyResult<()> {
fn initgroups(user_name: PyStrRef, gid: Gid, vm: &VirtualMachine) -> PyResult<()> {
let user = user_name.to_cstring(vm)?;
let gid =
gid.ok_or_else(|| vm.new_errno_error(1, "Operation not permitted".to_string()))?;
unistd::initgroups(&user, gid).map_err(|err| err.into_pyexception(vm))
}
@@ -1347,15 +1306,11 @@ pub mod module {
#[cfg(not(any(target_os = "ios", target_os = "macos", target_os = "redox")))]
#[pyfunction]
fn setgroups(
group_ids: crate::function::ArgIterable<Option<Gid>>,
group_ids: crate::function::ArgIterable<Gid>,
vm: &VirtualMachine,
) -> PyResult<()> {
let gids = group_ids
.iter(vm)?
.collect::<Result<Option<Vec<_>>, _>>()?
.ok_or_else(|| vm.new_errno_error(1, "Operation not permitted".to_string()))?;
let ret = unistd::setgroups(&gids);
ret.map_err(|err| err.into_pyexception(vm))
let gids = group_ids.iter(vm)?.collect::<Result<Vec<_>, _>>()?;
unistd::setgroups(&gids).map_err(|err| err.into_pyexception(vm))
}
#[cfg(any(target_os = "linux", target_os = "freebsd", target_os = "macos"))]