diff --git a/vm/src/pyobject.rs b/vm/src/pyobject.rs index e353532fa..3d9297dd4 100644 --- a/vm/src/pyobject.rs +++ b/vm/src/pyobject.rs @@ -1069,6 +1069,15 @@ pub enum Either { B(B), } +impl Either, PyRef> { + pub fn into_object(self) -> PyObjectRef { + match self { + Either::A(a) => a.into_object(), + Either::B(b) => b.into_object(), + } + } +} + /// This allows a builtin method to accept arguments that may be one of two /// types, raising a `TypeError` if it is neither. /// diff --git a/vm/src/stdlib/io.rs b/vm/src/stdlib/io.rs index 82b06ba11..7d4151dbf 100644 --- a/vm/src/stdlib/io.rs +++ b/vm/src/stdlib/io.rs @@ -6,7 +6,6 @@ use std::io::prelude::*; use std::io::Cursor; use std::io::SeekFrom; -use num_bigint::ToBigInt; use num_traits::ToPrimitive; use super::os; @@ -14,12 +13,12 @@ use crate::function::{OptionalArg, PyFuncArgs}; use crate::obj::objbytearray::PyByteArray; use crate::obj::objbytes; use crate::obj::objbytes::PyBytes; -use crate::obj::objint; -use crate::obj::objstr; +use crate::obj::objint::{self, PyIntRef}; +use crate::obj::objstr::{self, PyStringRef}; use crate::obj::objtype; use crate::obj::objtype::PyClassRef; use crate::pyobject::TypeProtocol; -use crate::pyobject::{BufferProtocol, PyObjectRef, PyRef, PyResult, PyValue}; +use crate::pyobject::{BufferProtocol, Either, PyObjectRef, PyRef, PyResult, PyValue}; use crate::vm::VirtualMachine; fn byte_count(bytes: OptionalArg>) -> i64 { @@ -102,7 +101,7 @@ impl PyValue for PyStringIO { impl PyStringIORef { //write string to underlying vector - fn write(self, data: objstr::PyStringRef, vm: &VirtualMachine) -> PyResult { + fn write(self, data: PyStringRef, vm: &VirtualMachine) -> PyResult { let bytes = &data.value.clone().into_bytes(); match self.buffer.borrow_mut().write(bytes.to_vec()) { @@ -312,32 +311,34 @@ fn compute_c_flag(mode: &str) -> u32 { flag as u32 } -fn file_io_init(vm: &VirtualMachine, args: PyFuncArgs) -> PyResult { - arg_check!( - vm, - args, - required = [(file_io, None), (name, None)], - optional = [(mode, Some(vm.ctx.str_type()))] - ); - - let file_no = if objtype::isinstance(&name, &vm.ctx.str_type()) { - let rust_mode = mode.map_or("r".to_string(), objstr::get_value); - let args = vec![ - name.clone(), - vm.ctx - .new_int(compute_c_flag(&rust_mode).to_bigint().unwrap()), - ]; - os::os_open(vm, PyFuncArgs::new(args, vec![]))? - } else if objtype::isinstance(&name, &vm.ctx.int_type()) { - name.clone() - } else { - return Err(vm.new_type_error("name parameter must be string or int".to_string())); +fn file_io_init( + file_io: PyObjectRef, + name: Either, + mode: OptionalArg, + vm: &VirtualMachine, +) -> PyResult { + let file_no = match &name { + Either::A(name) => { + let mode = match mode { + OptionalArg::Present(mode) => compute_c_flag(mode.as_str()), + OptionalArg::Missing => libc::O_RDONLY as _, + }; + let fno = os::os_open( + name.clone(), + mode as _, + OptionalArg::Missing, + OptionalArg::Missing, + vm, + )?; + vm.new_int(fno) + } + Either::B(fno) => fno.clone().into_object(), }; - vm.set_attr(file_io, "name", name.clone())?; - vm.set_attr(file_io, "fileno", file_no)?; - vm.set_attr(file_io, "closefd", vm.new_bool(false))?; - vm.set_attr(file_io, "closed", vm.new_bool(false))?; + vm.set_attr(&file_io, "name", name.into_object())?; + vm.set_attr(&file_io, "fileno", file_no)?; + vm.set_attr(&file_io, "closefd", vm.new_bool(false))?; + vm.set_attr(&file_io, "closed", vm.new_bool(false))?; Ok(vm.get_none()) } diff --git a/vm/src/stdlib/os.rs b/vm/src/stdlib/os.rs index 291c71146..69c177dc4 100644 --- a/vm/src/stdlib/os.rs +++ b/vm/src/stdlib/os.rs @@ -19,13 +19,13 @@ use nix::pty::openpty; use nix::unistd::{self, Gid, Pid, Uid, Whence}; use num_traits::cast::ToPrimitive; -use crate::function::{IntoPyNativeFunc, PyFuncArgs}; +use crate::function::{IntoPyNativeFunc, OptionalArg, PyFuncArgs}; use crate::obj::objbytes::PyBytesRef; use crate::obj::objdict::PyDictRef; -use crate::obj::objint::{self, PyInt, PyIntRef}; +use crate::obj::objint::{self, PyIntRef}; use crate::obj::objiter; use crate::obj::objset::PySet; -use crate::obj::objstr::{self, PyString, PyStringRef}; +use crate::obj::objstr::{self, PyStringRef}; use crate::obj::objtype::{self, PyClassRef}; use crate::pyobject::{ Either, ItemProtocol, PyClassImpl, PyObjectRef, PyRef, PyResult, PyValue, TryIntoRef, @@ -95,59 +95,58 @@ pub fn os_close(vm: &VirtualMachine, args: PyFuncArgs) -> PyResult { Ok(vm.get_none()) } +#[cfg(unix)] +type OpenFlags = i32; +#[cfg(windows)] +type OpenFlags = u32; + #[cfg(any(unix, windows))] -pub fn os_open(vm: &VirtualMachine, args: PyFuncArgs) -> PyResult { - arg_check!( - vm, - args, - required = [ - (name, Some(vm.ctx.str_type())), - (flags, Some(vm.ctx.int_type())) - ], - optional = [ - (_mode, Some(vm.ctx.int_type())), - (dir_fd, Some(vm.ctx.int_type())) - ] - ); - - let name = name.clone().downcast::().unwrap(); - let dir_fd = if let Some(obj) = dir_fd { - DirFd { - dir_fd: Some(obj.clone().downcast::().unwrap()), - } - } else { - DirFd::default() +pub fn os_open( + name: PyStringRef, + flags: OpenFlags, + _mode: OptionalArg, + dir_fd: OptionalArg, + vm: &VirtualMachine, +) -> PyResult { + let dir_fd = DirFd { + dir_fd: dir_fd.into_option(), }; - let fname = &make_path(vm, name, &dir_fd).value; + let fname = make_path(vm, name, &dir_fd); - let options = _set_file_model(&flags); + let mut options = OpenOptions::new(); + + macro_rules! bit_contains { + ($c:expr) => { + flags & $c as OpenFlags == $c as OpenFlags + }; + } + + if bit_contains!(libc::O_RDWR) { + options.read(true).write(true); + } else if bit_contains!(libc::O_WRONLY) { + options.write(true); + } else if bit_contains!(libc::O_RDONLY) { + options.read(true); + } + + if bit_contains!(libc::O_APPEND) { + options.append(true); + } + + if bit_contains!(libc::O_CREAT) { + if bit_contains!(libc::O_EXCL) { + options.create_new(true); + } else { + options.create(true); + } + } + + options.custom_flags(flags); let handle = options - .open(&fname) + .open(fname.as_str()) .map_err(|err| convert_io_error(vm, err))?; - Ok(vm.ctx.new_int(raw_file_number(handle))) -} - -#[cfg(unix)] -fn _set_file_model(flags: &PyObjectRef) -> OpenOptions { - let flags = objint::get_value(flags).to_i32().unwrap(); - let mut options = OpenOptions::new(); - options.read(flags == libc::O_RDONLY); - options.write(flags & libc::O_WRONLY != 0); - options.append(flags & libc::O_APPEND != 0); - options.custom_flags(flags); - options -} - -#[cfg(windows)] -fn _set_file_model(flags: &PyObjectRef) -> OpenOptions { - let flags = objint::get_value(flags).to_u32().unwrap(); - let mut options = OpenOptions::new(); - options.read((flags as i32) == libc::O_RDONLY); - options.write((flags as i32) & libc::O_WRONLY != 0); - options.append((flags as i32) & libc::O_APPEND != 0); - options.custom_flags(flags); - options + Ok(raw_file_number(handle)) } #[cfg(all(not(unix), not(windows)))]