Merge pull request #1380 from mpajkowski/os_scandir_contextmanager

os.scandir refinement
This commit is contained in:
Aviv Palivoda
2019-09-30 23:53:33 +03:00
committed by GitHub
2 changed files with 76 additions and 20 deletions

View File

@@ -63,27 +63,30 @@ assert os.fspath(b"Testing") == b"Testing"
assert_raises(TypeError, lambda: os.fspath([1,2,3]))
class TestWithTempDir():
def __enter__(self):
if os.name == "nt":
base_folder = os.environ["TEMP"]
else:
base_folder = "/tmp"
name = os.path.join(base_folder, "rustpython_test_os_" + str(int(time.time())))
os.mkdir(name)
self.name = name
return name
def __enter__(self):
if os.name == "nt":
base_folder = os.environ["TEMP"]
else:
base_folder = "/tmp"
def __exit__(self, exc_type, exc_val, exc_tb):
# TODO: Delete temp dir
pass
name = os.path.join(base_folder, "rustpython_test_os_" + str(int(time.time())))
while os.path.isdir(name):
name = name + "_"
os.mkdir(name)
self.name = name
return name
def __exit__(self, exc_type, exc_val, exc_tb):
pass
class TestWithTempCurrentDir():
def __enter__(self):
self.prev_cwd = os.getcwd()
def __enter__(self):
self.prev_cwd = os.getcwd()
def __exit__(self, exc_type, exc_val, exc_tb):
os.chdir(self.prev_cwd)
def __exit__(self, exc_type, exc_val, exc_tb):
os.chdir(self.prev_cwd)
FILE_NAME = "test1"
@@ -286,3 +289,27 @@ if "win" not in sys.platform:
assert_raises(OSError, lambda: os.ttyname(9999))
os.close(b)
os.close(a)
with TestWithTempDir() as tmpdir:
for i in range(0, 4):
file_name = os.path.join(tmpdir, 'file' + str(i))
with open(file_name, 'w') as f:
f.write('test')
expected_files = ['file0', 'file1', 'file2', 'file3']
dir_iter = os.scandir(tmpdir)
collected_files = [dir_entry.name for dir_entry in dir_iter]
assert set(collected_files) == set(expected_files)
with assert_raises(StopIteration):
next(dir_iter)
dir_iter.close()
with TestWithTempCurrentDir():
os.chdir(tmpdir)
with os.scandir() as dir_iter:
collected_files = [dir_entry.name for dir_entry in dir_iter]
assert set(collected_files) == set(expected_files)

View File

@@ -1,5 +1,5 @@
use num_cpus;
use std::cell::RefCell;
use std::cell::{Cell, RefCell};
use std::ffi;
use std::fs::File;
use std::fs::OpenOptions;
@@ -558,6 +558,7 @@ impl DirEntryRef {
#[derive(Debug)]
struct ScandirIterator {
entries: RefCell<fs::ReadDir>,
exhausted: Cell<bool>,
}
impl PyValue for ScandirIterator {
@@ -570,25 +571,53 @@ impl PyValue for ScandirIterator {
impl ScandirIterator {
#[pymethod(name = "__next__")]
fn next(&self, vm: &VirtualMachine) -> PyResult {
if self.exhausted.get() {
return Err(objiter::new_stop_iteration(vm));
}
match self.entries.borrow_mut().next() {
Some(entry) => match entry {
Ok(entry) => Ok(DirEntry { entry }.into_ref(vm).into_object()),
Err(s) => Err(convert_io_error(vm, s)),
},
None => Err(objiter::new_stop_iteration(vm)),
None => {
self.exhausted.set(true);
Err(objiter::new_stop_iteration(vm))
}
}
}
#[pymethod]
fn close(&self, _vm: &VirtualMachine) {
self.exhausted.set(true);
}
#[pymethod(name = "__iter__")]
fn iter(zelf: PyRef<Self>, _vm: &VirtualMachine) -> PyRef<Self> {
zelf
}
#[pymethod(name = "__enter__")]
fn enter(zelf: PyRef<Self>, _vm: &VirtualMachine) -> PyRef<Self> {
zelf
}
#[pymethod(name = "__exit__")]
fn exit(zelf: PyRef<Self>, _args: PyFuncArgs, vm: &VirtualMachine) {
zelf.close(vm)
}
}
fn os_scandir(path: PyStringRef, vm: &VirtualMachine) -> PyResult {
match fs::read_dir(path.as_str()) {
fn os_scandir(path: OptionalArg<PyStringRef>, vm: &VirtualMachine) -> PyResult {
let path = match path {
OptionalArg::Present(ref path) => path.as_str(),
OptionalArg::Missing => ".",
};
match fs::read_dir(path) {
Ok(iter) => Ok(ScandirIterator {
entries: RefCell::new(iter),
exhausted: Cell::new(false),
}
.into_ref(vm)
.into_object()),