diff --git a/Cargo.lock b/Cargo.lock index 300752a8a..1b5023e23 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1131,6 +1131,7 @@ dependencies = [ "hexf-parse 0.1.0 (registry+https://github.com/rust-lang/crates.io-index)", "indexmap 1.0.2 (registry+https://github.com/rust-lang/crates.io-index)", "itertools 0.8.0 (registry+https://github.com/rust-lang/crates.io-index)", + "kernel32-sys 0.2.2 (registry+https://github.com/rust-lang/crates.io-index)", "lazy_static 1.3.0 (registry+https://github.com/rust-lang/crates.io-index)", "lexical 2.2.1 (registry+https://github.com/rust-lang/crates.io-index)", "libc 0.2.60 (registry+https://github.com/rust-lang/crates.io-index)", diff --git a/tests/snippets/stdlib_io.py b/tests/snippets/stdlib_io.py index 184146081..78f1287a9 100644 --- a/tests/snippets/stdlib_io.py +++ b/tests/snippets/stdlib_io.py @@ -1,5 +1,6 @@ from io import BufferedReader, FileIO, StringIO, BytesIO import os +from testutils import assertRaises fi = FileIO('README.md') assert fi.seekable() @@ -23,3 +24,9 @@ fd = os.open('README.md', os.O_RDONLY) with FileIO(fd) as fio: res2 = fio.read() assert res == res2 + +fi = FileIO('README.md') +fi.read() +fi.close() +with assertRaises(ValueError): + fi.read() diff --git a/vm/Cargo.toml b/vm/Cargo.toml index 64160547a..b1f1873f2 100644 --- a/vm/Cargo.toml +++ b/vm/Cargo.toml @@ -79,3 +79,6 @@ libz-sys = "1.0.25" gethostname = "0.2.0" subprocess = "0.1.18" num_cpus = "1.0" + +[target."cfg(windows)".dependencies] +kernel32-sys = "0.2.2" diff --git a/vm/src/stdlib/io.rs b/vm/src/stdlib/io.rs index 22f17aa5f..1178724f4 100644 --- a/vm/src/stdlib/io.rs +++ b/vm/src/stdlib/io.rs @@ -431,6 +431,31 @@ fn file_io_write(vm: &VirtualMachine, args: PyFuncArgs) -> PyResult { } } +#[cfg(windows)] +fn file_io_close(vm: &VirtualMachine, args: PyFuncArgs) -> PyResult { + use std::os::windows::io::IntoRawHandle; + arg_check!(vm, args, required = [(file_io, None)]); + let file_no = vm.get_attribute(file_io.clone(), "fileno")?; + let raw_fd = objint::get_value(&file_no).to_i64().unwrap(); + let handle = os::rust_file(raw_fd); + let raw_handle = handle.into_raw_handle(); + unsafe { + kernel32::CloseHandle(raw_handle); + } + Ok(vm.ctx.none()) +} + +#[cfg(unix)] +fn file_io_close(vm: &VirtualMachine, args: PyFuncArgs) -> PyResult { + arg_check!(vm, args, required = [(file_io, None)]); + let file_no = vm.get_attribute(file_io.clone(), "fileno")?; + let raw_fd = objint::get_value(&file_no).to_i32().unwrap(); + unsafe { + libc::close(raw_fd); + } + Ok(vm.ctx.none()) +} + fn file_io_seekable(vm: &VirtualMachine, _args: PyFuncArgs) -> PyResult { Ok(vm.ctx.new_bool(true)) } @@ -678,6 +703,7 @@ pub fn make_module(vm: &VirtualMachine) -> PyObjectRef { "read" => ctx.new_rustfunc(file_io_read), "readinto" => ctx.new_rustfunc(file_io_readinto), "write" => ctx.new_rustfunc(file_io_write), + "close" => ctx.new_rustfunc(file_io_close), "seekable" => ctx.new_rustfunc(file_io_seekable) });