From 12a30288062ee8c076f05ff29f8d2204ea42c5b9 Mon Sep 17 00:00:00 2001 From: Aviv Palivoda Date: Thu, 13 Jun 2019 20:41:37 +0300 Subject: [PATCH 1/6] Import all from_list in one __import__ call --- compiler/src/bytecode.rs | 4 +-- compiler/src/compile.rs | 56 ++++++++++++++++++++++++++----------- compiler/src/symboltable.rs | 20 +++++++++---- parser/src/ast.rs | 9 ++++-- parser/src/python.lalrpop | 23 ++++++++------- vm/src/frame.rs | 38 ++++++++++++------------- 6 files changed, 93 insertions(+), 57 deletions(-) diff --git a/compiler/src/bytecode.rs b/compiler/src/bytecode.rs index aede69c93..c87bc0ae0 100644 --- a/compiler/src/bytecode.rs +++ b/compiler/src/bytecode.rs @@ -50,7 +50,7 @@ pub enum NameScope { pub enum Instruction { Import { name: String, - symbol: Option, + symbols: Vec, }, ImportStar { name: String, @@ -330,7 +330,7 @@ impl Instruction { } match self { - Import { name, symbol } => w!(Import, name, format!("{:?}", symbol)), + Import { name, symbols } => w!(Import, name, format!("{:?}", symbols)), ImportStar { name } => w!(ImportStar, name), LoadName { name, scope } => w!(LoadName, name, format!("{:?}", scope)), StoreName { name, scope } => w!(StoreName, name, format!("{:?}", scope)), diff --git a/compiler/src/compile.rs b/compiler/src/compile.rs index 30510c200..dc83fad2e 100644 --- a/compiler/src/compile.rs +++ b/compiler/src/compile.rs @@ -250,29 +250,51 @@ impl Compiler { ast::Statement::Import { import_parts } => { for ast::SingleImport { module, - symbol, + symbols, alias, } in import_parts { - match symbol { - Some(name) if name == "*" => { - self.emit(Instruction::ImportStar { - name: module.clone(), - }); - } - _ => { + if let Some(alias) = alias { + self.emit(Instruction::Import { + name: module.clone(), + symbols: vec![], + }); + self.store_name(&alias); + } else { + if symbols.is_empty() { self.emit(Instruction::Import { name: module.clone(), - symbol: symbol.clone(), + symbols: vec![], }); - let name = match alias { - Some(alias) => alias.clone(), - None => match symbol { - Some(symbol) => symbol.clone(), - None => module.clone(), - }, - }; - self.store_name(&name); + self.store_name(&module.clone()); + } else { + let mut import_star = false; + let mut symbols_strings = vec![]; + let mut names = vec![]; + for ast::ImportSymbol { symbol, alias } in symbols { + if symbol == "*" { + import_star = true; + } + symbols_strings.push(symbol.to_string()); + names.insert( + 0, + match alias { + Some(alias) => alias, + None => symbol, + }, + ); + } + if import_star { + self.emit(Instruction::ImportStar { + name: module.clone(), + }); + } else { + self.emit(Instruction::Import { + name: module.clone(), + symbols: symbols_strings, + }); + names.iter().for_each(|name| self.store_name(&name)); + } } } } diff --git a/compiler/src/symboltable.rs b/compiler/src/symboltable.rs index 9687b0ec7..04e5d9ede 100644 --- a/compiler/src/symboltable.rs +++ b/compiler/src/symboltable.rs @@ -300,14 +300,22 @@ impl SymbolTableBuilder { for part in import_parts { if let Some(alias) = &part.alias { // `import mymodule as myalias` - // `from mymodule import myimportname as myalias` self.register_name(alias, SymbolRole::Assigned)?; - } else if let Some(symbol) = &part.symbol { - // `from mymodule import myimport` - self.register_name(symbol, SymbolRole::Assigned)?; } else { - // `import module` - self.register_name(&part.module, SymbolRole::Assigned)?; + if part.symbols.is_empty() { + // `import module` + self.register_name(&part.module, SymbolRole::Assigned)?; + } else { + // `from mymodule import myimport` + for symbol in &part.symbols { + if let Some(alias) = &symbol.alias { + // `from mymodule import myimportname as myalias` + self.register_name(alias, SymbolRole::Assigned)?; + } else { + self.register_name(&symbol.symbol, SymbolRole::Assigned)?; + } + } + } } } } diff --git a/parser/src/ast.rs b/parser/src/ast.rs index 941e04131..8fc2ba53d 100644 --- a/parser/src/ast.rs +++ b/parser/src/ast.rs @@ -27,12 +27,17 @@ pub struct Program { pub statements: Vec, } +#[derive(Debug, PartialEq)] +pub struct ImportSymbol { + pub symbol: String, + pub alias: Option, +} + #[derive(Debug, PartialEq)] pub struct SingleImport { pub module: String, - // (symbol name in module, name it should be assigned locally) - pub symbol: Option, pub alias: Option, + pub symbols: Vec, } #[derive(Debug, PartialEq)] diff --git a/parser/src/python.lalrpop b/parser/src/python.lalrpop index 8d6dd7cac..ac8669b8e 100644 --- a/parser/src/python.lalrpop +++ b/parser/src/python.lalrpop @@ -206,7 +206,7 @@ ImportStatement: ast::LocatedStatement = { .map(|(n, a)| ast::SingleImport { module: n.to_string(), - symbol: None, + symbols: vec![], alias: a.clone() }) .collect() @@ -217,15 +217,18 @@ ImportStatement: ast::LocatedStatement = { ast::LocatedStatement { location: loc, node: ast::Statement::Import { - import_parts: i - .iter() - .map(|(i, a)| - ast::SingleImport { - module: n.to_string(), - symbol: Some(i.to_string()), - alias: a.clone() - }) - .collect() + import_parts: vec![ + ast::SingleImport { + module: n.to_string(), + symbols: i.iter() + .map(|(i, a)| + ast::ImportSymbol { + symbol: i.to_string(), + alias: a.clone(), + }) + .collect(), + alias: None + }] }, } }, diff --git a/vm/src/frame.rs b/vm/src/frame.rs index 2506aa455..8aaed0708 100644 --- a/vm/src/frame.rs +++ b/vm/src/frame.rs @@ -357,8 +357,8 @@ impl Frame { } bytecode::Instruction::Import { ref name, - ref symbol, - } => self.import(vm, name, symbol), + ref symbols, + } => self.import(vm, name, symbols), bytecode::Instruction::ImportStar { ref name } => self.import_star(vm, name), bytecode::Instruction::LoadName { ref name, @@ -907,25 +907,23 @@ impl Frame { } } - fn import(&self, vm: &VirtualMachine, module: &str, symbol: &Option) -> FrameResult { - let from_list = match symbol { - Some(symbol) => vm.ctx.new_tuple(vec![vm.ctx.new_str(symbol.to_string())]), - None => vm.ctx.new_tuple(vec![]), - }; - let module = vm.import(module, &from_list)?; + fn import(&self, vm: &VirtualMachine, module: &str, symbols: &Vec) -> FrameResult { + let mut from_list = vec![]; + for symbol in symbols { + from_list.push(vm.ctx.new_str(symbol.to_string())); + } + let module = vm.import(module, &vm.ctx.new_tuple(from_list))?; - // If we're importing a symbol, look it up and use it, otherwise construct a module and return - // that - let obj = match symbol { - Some(symbol) => vm.get_attribute(module, symbol.as_str()).map_err(|_| { - let import_error = vm.context().exceptions.import_error.clone(); - vm.new_exception(import_error, format!("cannot import name '{}'", symbol)) - }), - None => Ok(module), - }; - - // Push module on stack: - self.push_value(obj?); + if symbols.is_empty() { + self.push_value(module); + } else { + for symbol in symbols { + let obj = vm + .get_attribute(module.clone(), symbol.as_str()) + .map_err(|_| vm.new_import_error(format!("cannot import name '{}'", symbol))); + self.push_value(obj?); + } + } Ok(None) } From 4938c03d6f1938f88692e910d06adbf39b3bf69f Mon Sep 17 00:00:00 2001 From: Aviv Palivoda Date: Fri, 14 Jun 2019 08:45:30 +0300 Subject: [PATCH 2/6] Improve compiler import --- compiler/src/compile.rs | 61 ++++++++++++++++++++--------------------- 1 file changed, 30 insertions(+), 31 deletions(-) diff --git a/compiler/src/compile.rs b/compiler/src/compile.rs index dc83fad2e..8b806b1d1 100644 --- a/compiler/src/compile.rs +++ b/compiler/src/compile.rs @@ -255,46 +255,45 @@ impl Compiler { } in import_parts { if let Some(alias) = alias { + // import module as alias self.emit(Instruction::Import { name: module.clone(), symbols: vec![], }); self.store_name(&alias); + } else if symbols.is_empty() { + // import module + self.emit(Instruction::Import { + name: module.clone(), + symbols: vec![], + }); + self.store_name(&module.clone()); } else { - if symbols.is_empty() { + let import_star = symbols + .iter() + .any(|import_symbol| import_symbol.symbol == "*"); + if import_star { + // from module import * + self.emit(Instruction::ImportStar { + name: module.clone(), + }); + } else { + // from module import symbol + // from module import symbol as alias + let (names, symbols_strings): (Vec, Vec) = symbols + .iter() + .map(|ast::ImportSymbol { symbol, alias }| { + ( + alias.clone().unwrap_or_else(|| symbol.to_string()), + symbol.to_string(), + ) + }) + .unzip(); self.emit(Instruction::Import { name: module.clone(), - symbols: vec![], + symbols: symbols_strings, }); - self.store_name(&module.clone()); - } else { - let mut import_star = false; - let mut symbols_strings = vec![]; - let mut names = vec![]; - for ast::ImportSymbol { symbol, alias } in symbols { - if symbol == "*" { - import_star = true; - } - symbols_strings.push(symbol.to_string()); - names.insert( - 0, - match alias { - Some(alias) => alias, - None => symbol, - }, - ); - } - if import_star { - self.emit(Instruction::ImportStar { - name: module.clone(), - }); - } else { - self.emit(Instruction::Import { - name: module.clone(), - symbols: symbols_strings, - }); - names.iter().for_each(|name| self.store_name(&name)); - } + names.iter().rev().for_each(|name| self.store_name(&name)); } } } From e61fa6dd7482d4bdfebddad1462c596604ba9853 Mon Sep 17 00:00:00 2001 From: Aviv Palivoda Date: Fri, 14 Jun 2019 08:49:20 +0300 Subject: [PATCH 3/6] Use Iterator to create from_list --- vm/src/frame.rs | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/vm/src/frame.rs b/vm/src/frame.rs index 8aaed0708..8ea114854 100644 --- a/vm/src/frame.rs +++ b/vm/src/frame.rs @@ -908,10 +908,10 @@ impl Frame { } fn import(&self, vm: &VirtualMachine, module: &str, symbols: &Vec) -> FrameResult { - let mut from_list = vec![]; - for symbol in symbols { - from_list.push(vm.ctx.new_str(symbol.to_string())); - } + let from_list = symbols + .iter() + .map(|symbol| vm.ctx.new_str(symbol.to_string())) + .collect(); let module = vm.import(module, &vm.ctx.new_tuple(from_list))?; if symbols.is_empty() { From 60e799727f9945510e9de5a58c01fc2d95d09682 Mon Sep 17 00:00:00 2001 From: Aviv Palivoda Date: Fri, 14 Jun 2019 15:25:24 +0300 Subject: [PATCH 4/6] Support reversed on sequence --- tests/snippets/builtin_reversed.py | 3 +++ vm/src/builtins.rs | 22 ++++++++++++++++------ vm/src/obj/objiter.rs | 29 ++++++++++++++++++----------- vm/src/pyobject.rs | 1 + 4 files changed, 38 insertions(+), 17 deletions(-) diff --git a/tests/snippets/builtin_reversed.py b/tests/snippets/builtin_reversed.py index 2bbfcb98a..261b5c326 100644 --- a/tests/snippets/builtin_reversed.py +++ b/tests/snippets/builtin_reversed.py @@ -1 +1,4 @@ assert list(reversed(range(5))) == [4, 3, 2, 1, 0] + +l = [5,4,3,2,1] +assert list(reversed(l)) == [1,2,3,4,5] diff --git a/vm/src/builtins.rs b/vm/src/builtins.rs index 9779a5054..6982e0493 100644 --- a/vm/src/builtins.rs +++ b/vm/src/builtins.rs @@ -2,12 +2,13 @@ //! //! Implements functions listed here: https://docs.python.org/3/library/builtins.html +use std::cell::Cell; use std::char; use std::io::{self, Write}; use std::str; use num_bigint::Sign; -use num_traits::{Signed, Zero}; +use num_traits::{Signed, ToPrimitive, Zero}; use crate::compile; use crate::obj::objbool; @@ -684,11 +685,20 @@ fn builtin_repr(obj: PyObjectRef, vm: &VirtualMachine) -> PyResult fn builtin_reversed(vm: &VirtualMachine, args: PyFuncArgs) -> PyResult { arg_check!(vm, args, required = [(obj, None)]); - // TODO: fallback to using __len__ and __getitem__, if object supports sequence protocol - let method = vm.get_method_or_type_error(obj.clone(), "__reversed__", || { - format!("argument to reversed() must be a sequence") - })?; - vm.invoke(method, PyFuncArgs::default()) + if let Some(reversed_method) = vm.get_method(obj.clone(), "__reversed__") { + vm.invoke(reversed_method?, PyFuncArgs::default()) + } else { + vm.get_method_or_type_error(obj.clone(), "__getitem__", || { + format!("argument to reversed() must be a sequence") + })?; + let len = vm.call_method(&obj.clone(), "__len__", PyFuncArgs::default())?; + let obj_iterator = objiter::PySequenceIterator { + position: Cell::new(objint::get_value(&len).to_isize().unwrap() - 1), + obj: obj.clone(), + reversed: true, + }; + Ok(obj_iterator.into_ref(vm).into_object()) + } } fn builtin_round(vm: &VirtualMachine, args: PyFuncArgs) -> PyResult { diff --git a/vm/src/obj/objiter.rs b/vm/src/obj/objiter.rs index 9aebc85f6..0fcd078c7 100644 --- a/vm/src/obj/objiter.rs +++ b/vm/src/obj/objiter.rs @@ -28,6 +28,7 @@ pub fn get_iter(vm: &VirtualMachine, iter_target: &PyObjectRef) -> PyResult { let obj_iterator = PySequenceIterator { position: Cell::new(0), obj: iter_target.clone(), + reversed: false, }; Ok(obj_iterator.into_ref(vm).into_object()) } @@ -80,8 +81,9 @@ pub fn new_stop_iteration(vm: &VirtualMachine) -> PyObjectRef { #[pyclass] #[derive(Debug)] pub struct PySequenceIterator { - pub position: Cell, + pub position: Cell, pub obj: PyObjectRef, + pub reversed: bool, } impl PyValue for PySequenceIterator { @@ -94,17 +96,22 @@ impl PyValue for PySequenceIterator { impl PySequenceIterator { #[pymethod(name = "__next__")] fn next(&self, vm: &VirtualMachine) -> PyResult { - let number = vm.ctx.new_int(self.position.get()); - match vm.call_method(&self.obj, "__getitem__", vec![number]) { - Ok(val) => { - self.position.set(self.position.get() + 1); - Ok(val) + if self.position.get() >= 0 { + let step: isize = if self.reversed { -1 } else { 1 }; + let number = vm.ctx.new_int(self.position.get()); + match vm.call_method(&self.obj, "__getitem__", vec![number]) { + Ok(val) => { + self.position.set(self.position.get() + step); + Ok(val) + } + Err(ref e) if objtype::isinstance(&e, &vm.ctx.exceptions.index_error) => { + Err(new_stop_iteration(vm)) + } + // also catches stop_iteration => stop_iteration + Err(e) => Err(e), } - Err(ref e) if objtype::isinstance(&e, &vm.ctx.exceptions.index_error) => { - Err(new_stop_iteration(vm)) - } - // also catches stop_iteration => stop_iteration - Err(e) => Err(e), + } else { + Err(new_stop_iteration(vm)) } } diff --git a/vm/src/pyobject.rs b/vm/src/pyobject.rs index d0bccd856..e1c2cd42c 100644 --- a/vm/src/pyobject.rs +++ b/vm/src/pyobject.rs @@ -1149,6 +1149,7 @@ where objiter::PySequenceIterator { position: Cell::new(0), obj: obj.clone(), + reversed: false, } .into_ref(vm) .into_object(), From 82f83ef345f02c5afa45a8f54f0e898f6845eb1a Mon Sep 17 00:00:00 2001 From: Aviv Palivoda Date: Fri, 14 Jun 2019 19:20:00 +0300 Subject: [PATCH 5/6] Support more open flags --- Cargo.lock | 1 + tests/snippets/stdlib_os.py | 16 +++++---- vm/Cargo.toml | 1 + vm/src/builtins.rs | 1 + vm/src/exceptions.rs | 3 ++ vm/src/stdlib/io.rs | 25 ++++++++------ vm/src/stdlib/os.rs | 65 +++++++++++++++++++++++++++++-------- 7 files changed, 82 insertions(+), 30 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 7bb6469d7..ad9272003 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -939,6 +939,7 @@ name = "rustpython_vm" version = "0.1.0" dependencies = [ "bincode 1.1.4 (registry+https://github.com/rust-lang/crates.io-index)", + "bitflags 1.1.0 (registry+https://github.com/rust-lang/crates.io-index)", "byteorder 1.3.2 (registry+https://github.com/rust-lang/crates.io-index)", "caseless 0.2.1 (registry+https://github.com/rust-lang/crates.io-index)", "crc 1.8.1 (registry+https://github.com/rust-lang/crates.io-index)", diff --git a/tests/snippets/stdlib_os.py b/tests/snippets/stdlib_os.py index fe0b75721..6280fc3bb 100644 --- a/tests/snippets/stdlib_os.py +++ b/tests/snippets/stdlib_os.py @@ -4,12 +4,13 @@ import stat from testutils import assert_raises -fd = os.open('README.md', 0) +fd = os.open('README.md', os.O_RDONLY) assert fd > 0 os.close(fd) assert_raises(OSError, lambda: os.read(fd, 10)) -assert_raises(FileNotFoundError, lambda: os.open('DOES_NOT_EXIST', 0)) +assert_raises(FileNotFoundError, lambda: os.open('DOES_NOT_EXIST', os.O_RDONLY)) +assert_raises(FileNotFoundError, lambda: os.open('DOES_NOT_EXIST', os.O_WRONLY)) assert os.O_RDONLY == 0 @@ -88,14 +89,17 @@ CONTENT3 = b"BOYA" with TestWithTempDir() as tmpdir: fname = os.path.join(tmpdir, FILE_NAME) - with open(fname, "wb"): - pass - fd = os.open(fname, 1) + fd = os.open(fname, os.O_WRONLY | os.O_CREAT | os.O_EXCL) assert os.write(fd, CONTENT2) == len(CONTENT2) + os.close(fd) + + fd = os.open(fname, os.O_WRONLY | os.O_APPEND) assert os.write(fd, CONTENT3) == len(CONTENT3) os.close(fd) - fd = os.open(fname, 0) + assert_raises(FileExistsError, lambda: os.open(fname, os.O_WRONLY | os.O_CREAT | os.O_EXCL)) + + fd = os.open(fname, os.O_RDONLY) assert os.read(fd, len(CONTENT2)) == CONTENT2 assert os.read(fd, len(CONTENT3)) == CONTENT3 os.close(fd) diff --git a/vm/Cargo.toml b/vm/Cargo.toml index de53223ae..c7159acd6 100644 --- a/vm/Cargo.toml +++ b/vm/Cargo.toml @@ -32,6 +32,7 @@ indexmap = "1.0.2" crc = "^1.0.0" bincode = "1.1.4" unicode_categories = "0.1.1" +bitflags = "1.1" # TODO: release and publish to crates.io diff --git a/vm/src/builtins.rs b/vm/src/builtins.rs index 6982e0493..570b1fbbf 100644 --- a/vm/src/builtins.rs +++ b/vm/src/builtins.rs @@ -856,6 +856,7 @@ pub fn make_module(vm: &VirtualMachine, module: PyObjectRef) { "IndexError" => ctx.exceptions.index_error.clone(), "ImportError" => ctx.exceptions.import_error.clone(), "FileNotFoundError" => ctx.exceptions.file_not_found_error.clone(), + "FileExistsError" => ctx.exceptions.file_exists_error.clone(), "StopIteration" => ctx.exceptions.stop_iteration.clone(), "ZeroDivisionError" => ctx.exceptions.zero_division_error.clone(), "KeyError" => ctx.exceptions.key_error.clone(), diff --git a/vm/src/exceptions.rs b/vm/src/exceptions.rs index 4f1f94744..0c243d0a1 100644 --- a/vm/src/exceptions.rs +++ b/vm/src/exceptions.rs @@ -148,6 +148,7 @@ pub struct ExceptionZoo { pub base_exception_type: PyClassRef, pub exception_type: PyClassRef, pub file_not_found_error: PyClassRef, + pub file_exists_error: PyClassRef, pub import_error: PyClassRef, pub index_error: PyClassRef, pub key_error: PyClassRef, @@ -203,6 +204,7 @@ impl ExceptionZoo { let not_implemented_error = create_type("NotImplementedError", &type_type, &runtime_error); let file_not_found_error = create_type("FileNotFoundError", &type_type, &os_error); let permission_error = create_type("PermissionError", &type_type, &os_error); + let file_exists_error = create_type("FileExistsError", &type_type, &os_error); let warning = create_type("Warning", &type_type, &exception_type); let bytes_warning = create_type("BytesWarning", &type_type, &warning); @@ -224,6 +226,7 @@ impl ExceptionZoo { base_exception_type, exception_type, file_not_found_error, + file_exists_error, import_error, index_error, key_error, diff --git a/vm/src/stdlib/io.rs b/vm/src/stdlib/io.rs index 2dafc7137..d7db77bc6 100644 --- a/vm/src/stdlib/io.rs +++ b/vm/src/stdlib/io.rs @@ -21,16 +21,6 @@ use crate::obj::objtype::PyClassRef; use crate::pyobject::{BufferProtocol, PyObjectRef, PyRef, PyResult, PyValue}; use crate::vm::VirtualMachine; -fn compute_c_flag(mode: &str) -> u16 { - match mode { - "w" => 512, - "x" => 512, - "a" => 8, - "+" => 2, - _ => 0, - } -} - #[derive(Debug)] struct PyStringIO { data: RefCell, @@ -132,6 +122,21 @@ fn buffered_reader_read(vm: &VirtualMachine, args: PyFuncArgs) -> PyResult { Ok(vm.ctx.new_bytes(result)) } +fn compute_c_flag(mode: &str) -> u32 { + let flags = match mode { + "w" => os::FileCreationFlags::O_WRONLY | os::FileCreationFlags::O_CREAT, + "x" => { + os::FileCreationFlags::O_WRONLY + | os::FileCreationFlags::O_CREAT + | os::FileCreationFlags::O_EXCL + } + "a" => os::FileCreationFlags::O_APPEND, + "+" => os::FileCreationFlags::O_RDWR, + _ => os::FileCreationFlags::O_RDONLY, + }; + flags.bits() +} + fn file_io_init(vm: &VirtualMachine, args: PyFuncArgs) -> PyResult { arg_check!( vm, diff --git a/vm/src/stdlib/os.rs b/vm/src/stdlib/os.rs index 14d720df6..e6ab64914 100644 --- a/vm/src/stdlib/os.rs +++ b/vm/src/stdlib/os.rs @@ -5,6 +5,7 @@ use std::io::{self, ErrorKind, Read, Write}; use std::time::{Duration, SystemTime}; use std::{env, fs}; +use bitflags::bitflags; use num_traits::cast::ToPrimitive; use crate::function::{IntoPyNativeFunc, PyFuncArgs}; @@ -81,13 +82,26 @@ pub fn os_close(vm: &VirtualMachine, args: PyFuncArgs) -> PyResult { Ok(vm.get_none()) } +bitflags! { + pub struct FileCreationFlags: u32 { + // https://elixir.bootlin.com/linux/v4.8/source/include/uapi/asm-generic/fcntl.h + const O_RDONLY = 0o0000_0000; + const O_WRONLY = 0o0000_0001; + const O_RDWR = 0o0000_0002; + const O_CREAT = 0o0000_0100; + const O_EXCL = 0o0000_0200; + const O_APPEND = 0o0000_2000; + const O_NONBLOCK = 0o0000_4000; + } +} + pub fn os_open(vm: &VirtualMachine, args: PyFuncArgs) -> PyResult { arg_check!( vm, args, required = [ (name, Some(vm.ctx.str_type())), - (mode, Some(vm.ctx.int_type())) + (flags, Some(vm.ctx.int_type())) ], optional = [(dir_fd, Some(vm.ctx.int_type()))] ); @@ -102,14 +116,32 @@ pub fn os_open(vm: &VirtualMachine, args: PyFuncArgs) -> PyResult { }; let fname = &make_path(vm, name, &dir_fd).value; - let handle = match objint::get_value(mode).to_u16().unwrap() { - 0 => OpenOptions::new().read(true).open(&fname), - 1 => OpenOptions::new().write(true).open(&fname), - 2 => OpenOptions::new().read(true).write(true).open(&fname), - 512 => OpenOptions::new().write(true).create(true).open(&fname), - _ => OpenOptions::new().read(true).open(&fname), + let flags = FileCreationFlags::from_bits(objint::get_value(flags).to_u32().unwrap()) + .ok_or(vm.new_value_error("Unsupported flag".to_string()))?; + + let mut options = &mut OpenOptions::new(); + + if flags.contains(FileCreationFlags::O_WRONLY) { + options = options.write(true); + } else if flags.contains(FileCreationFlags::O_RDWR) { + options = options.read(true).write(true); + } else { + options = options.read(true); } - .map_err(|err| match err.kind() { + + if flags.contains(FileCreationFlags::O_APPEND) { + options = options.append(true); + } + + if flags.contains(FileCreationFlags::O_CREAT) { + if flags.contains(FileCreationFlags::O_EXCL) { + options = options.create_new(true); + } else { + options = options.create(true); + } + } + + let handle = options.open(&fname).map_err(|err| match err.kind() { ErrorKind::NotFound => { let exc_type = vm.ctx.exceptions.file_not_found_error.clone(); vm.new_exception(exc_type, format!("No such file or directory: {}", &fname)) @@ -118,6 +150,10 @@ pub fn os_open(vm: &VirtualMachine, args: PyFuncArgs) -> PyResult { let exc_type = vm.ctx.exceptions.permission_error.clone(); vm.new_exception(exc_type, format!("Permission denied: {}", &fname)) } + ErrorKind::AlreadyExists => { + let exc_type = vm.ctx.exceptions.file_exists_error.clone(); + vm.new_exception(exc_type, format!("File exists: {}", &fname)) + } _ => vm.new_value_error("Unhandled file IO error".to_string()), })?; @@ -743,12 +779,13 @@ pub fn make_module(vm: &VirtualMachine) -> PyObjectRef { "getcwd" => ctx.new_rustfunc(os_getcwd), "chdir" => ctx.new_rustfunc(os_chdir), "fspath" => ctx.new_rustfunc(os_fspath), - "O_RDONLY" => ctx.new_int(0), - "O_WRONLY" => ctx.new_int(1), - "O_RDWR" => ctx.new_int(2), - "O_NONBLOCK" => ctx.new_int(4), - "O_APPEND" => ctx.new_int(8), - "O_CREAT" => ctx.new_int(512) + "O_RDONLY" => ctx.new_int(FileCreationFlags::O_RDONLY.bits()), + "O_WRONLY" => ctx.new_int(FileCreationFlags::O_WRONLY.bits()), + "O_RDWR" => ctx.new_int(FileCreationFlags::O_RDWR.bits()), + "O_NONBLOCK" => ctx.new_int(FileCreationFlags::O_NONBLOCK.bits()), + "O_APPEND" => ctx.new_int(FileCreationFlags::O_APPEND.bits()), + "O_EXCL" => ctx.new_int(FileCreationFlags::O_EXCL.bits()), + "O_CREAT" => ctx.new_int(FileCreationFlags::O_CREAT.bits()) }); for support in support_funcs { From 92ad30ef6a76d4f2e37bbcdd6bbacfdefe858715 Mon Sep 17 00:00:00 2001 From: Aviv Palivoda Date: Sat, 15 Jun 2019 17:02:46 +0300 Subject: [PATCH 6/6] Add mode argument to os.open --- vm/src/stdlib/os.rs | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/vm/src/stdlib/os.rs b/vm/src/stdlib/os.rs index e6ab64914..60fc6b11b 100644 --- a/vm/src/stdlib/os.rs +++ b/vm/src/stdlib/os.rs @@ -103,7 +103,10 @@ pub fn os_open(vm: &VirtualMachine, args: PyFuncArgs) -> PyResult { (name, Some(vm.ctx.str_type())), (flags, Some(vm.ctx.int_type())) ], - optional = [(dir_fd, Some(vm.ctx.int_type()))] + optional = [ + (_mode, Some(vm.ctx.int_type())), + (dir_fd, Some(vm.ctx.int_type())) + ] ); let name = name.clone().downcast::().unwrap();