mirror of
https://github.com/RustPython/RustPython.git
synced 2026-06-09 22:49:57 +09:00
Merge pull request #1380 from mpajkowski/os_scandir_contextmanager
os.scandir refinement
This commit is contained in:
@@ -63,27 +63,30 @@ assert os.fspath(b"Testing") == b"Testing"
|
||||
assert_raises(TypeError, lambda: os.fspath([1,2,3]))
|
||||
|
||||
class TestWithTempDir():
|
||||
def __enter__(self):
|
||||
if os.name == "nt":
|
||||
base_folder = os.environ["TEMP"]
|
||||
else:
|
||||
base_folder = "/tmp"
|
||||
name = os.path.join(base_folder, "rustpython_test_os_" + str(int(time.time())))
|
||||
os.mkdir(name)
|
||||
self.name = name
|
||||
return name
|
||||
def __enter__(self):
|
||||
if os.name == "nt":
|
||||
base_folder = os.environ["TEMP"]
|
||||
else:
|
||||
base_folder = "/tmp"
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
# TODO: Delete temp dir
|
||||
pass
|
||||
name = os.path.join(base_folder, "rustpython_test_os_" + str(int(time.time())))
|
||||
|
||||
while os.path.isdir(name):
|
||||
name = name + "_"
|
||||
|
||||
os.mkdir(name)
|
||||
self.name = name
|
||||
return name
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
pass
|
||||
|
||||
class TestWithTempCurrentDir():
|
||||
def __enter__(self):
|
||||
self.prev_cwd = os.getcwd()
|
||||
def __enter__(self):
|
||||
self.prev_cwd = os.getcwd()
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
os.chdir(self.prev_cwd)
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
os.chdir(self.prev_cwd)
|
||||
|
||||
|
||||
FILE_NAME = "test1"
|
||||
@@ -286,3 +289,27 @@ if "win" not in sys.platform:
|
||||
assert_raises(OSError, lambda: os.ttyname(9999))
|
||||
os.close(b)
|
||||
os.close(a)
|
||||
|
||||
with TestWithTempDir() as tmpdir:
|
||||
for i in range(0, 4):
|
||||
file_name = os.path.join(tmpdir, 'file' + str(i))
|
||||
with open(file_name, 'w') as f:
|
||||
f.write('test')
|
||||
|
||||
expected_files = ['file0', 'file1', 'file2', 'file3']
|
||||
|
||||
dir_iter = os.scandir(tmpdir)
|
||||
collected_files = [dir_entry.name for dir_entry in dir_iter]
|
||||
|
||||
assert set(collected_files) == set(expected_files)
|
||||
|
||||
with assert_raises(StopIteration):
|
||||
next(dir_iter)
|
||||
|
||||
dir_iter.close()
|
||||
|
||||
with TestWithTempCurrentDir():
|
||||
os.chdir(tmpdir)
|
||||
with os.scandir() as dir_iter:
|
||||
collected_files = [dir_entry.name for dir_entry in dir_iter]
|
||||
assert set(collected_files) == set(expected_files)
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
use num_cpus;
|
||||
use std::cell::RefCell;
|
||||
use std::cell::{Cell, RefCell};
|
||||
use std::ffi;
|
||||
use std::fs::File;
|
||||
use std::fs::OpenOptions;
|
||||
@@ -558,6 +558,7 @@ impl DirEntryRef {
|
||||
#[derive(Debug)]
|
||||
struct ScandirIterator {
|
||||
entries: RefCell<fs::ReadDir>,
|
||||
exhausted: Cell<bool>,
|
||||
}
|
||||
|
||||
impl PyValue for ScandirIterator {
|
||||
@@ -570,25 +571,53 @@ impl PyValue for ScandirIterator {
|
||||
impl ScandirIterator {
|
||||
#[pymethod(name = "__next__")]
|
||||
fn next(&self, vm: &VirtualMachine) -> PyResult {
|
||||
if self.exhausted.get() {
|
||||
return Err(objiter::new_stop_iteration(vm));
|
||||
}
|
||||
|
||||
match self.entries.borrow_mut().next() {
|
||||
Some(entry) => match entry {
|
||||
Ok(entry) => Ok(DirEntry { entry }.into_ref(vm).into_object()),
|
||||
Err(s) => Err(convert_io_error(vm, s)),
|
||||
},
|
||||
None => Err(objiter::new_stop_iteration(vm)),
|
||||
None => {
|
||||
self.exhausted.set(true);
|
||||
Err(objiter::new_stop_iteration(vm))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[pymethod]
|
||||
fn close(&self, _vm: &VirtualMachine) {
|
||||
self.exhausted.set(true);
|
||||
}
|
||||
|
||||
#[pymethod(name = "__iter__")]
|
||||
fn iter(zelf: PyRef<Self>, _vm: &VirtualMachine) -> PyRef<Self> {
|
||||
zelf
|
||||
}
|
||||
|
||||
#[pymethod(name = "__enter__")]
|
||||
fn enter(zelf: PyRef<Self>, _vm: &VirtualMachine) -> PyRef<Self> {
|
||||
zelf
|
||||
}
|
||||
|
||||
#[pymethod(name = "__exit__")]
|
||||
fn exit(zelf: PyRef<Self>, _args: PyFuncArgs, vm: &VirtualMachine) {
|
||||
zelf.close(vm)
|
||||
}
|
||||
}
|
||||
|
||||
fn os_scandir(path: PyStringRef, vm: &VirtualMachine) -> PyResult {
|
||||
match fs::read_dir(path.as_str()) {
|
||||
fn os_scandir(path: OptionalArg<PyStringRef>, vm: &VirtualMachine) -> PyResult {
|
||||
let path = match path {
|
||||
OptionalArg::Present(ref path) => path.as_str(),
|
||||
OptionalArg::Missing => ".",
|
||||
};
|
||||
|
||||
match fs::read_dir(path) {
|
||||
Ok(iter) => Ok(ScandirIterator {
|
||||
entries: RefCell::new(iter),
|
||||
exhausted: Cell::new(false),
|
||||
}
|
||||
.into_ref(vm)
|
||||
.into_object()),
|
||||
|
||||
Reference in New Issue
Block a user