diff --git a/tests/snippets/stdlib_io.py b/tests/snippets/stdlib_io.py index 78f1287a9..ea52ad3e9 100644 --- a/tests/snippets/stdlib_io.py +++ b/tests/snippets/stdlib_io.py @@ -30,3 +30,9 @@ fi.read() fi.close() with assertRaises(ValueError): fi.read() + +with FileIO('README.md') as fio: + nres = fio.read(1) + assert len(nres) == 1 + nres = fio.read(2) + assert len(nres) == 2 \ No newline at end of file diff --git a/vm/src/stdlib/io.rs b/vm/src/stdlib/io.rs index 1178724f4..3e60ba81f 100644 --- a/vm/src/stdlib/io.rs +++ b/vm/src/stdlib/io.rs @@ -346,18 +346,36 @@ fn file_io_init(vm: &VirtualMachine, args: PyFuncArgs) -> PyResult { } fn file_io_read(vm: &VirtualMachine, args: PyFuncArgs) -> PyResult { - arg_check!(vm, args, required = [(file_io, None)]); + arg_check!( + vm, + args, + required = [(file_io, None)], + optional = [(read_byte, Some(vm.ctx.int_type()))] + ); let file_no = vm.get_attribute(file_io.clone(), "fileno")?; let raw_fd = objint::get_value(&file_no).to_i64().unwrap(); let mut handle = os::rust_file(raw_fd); - let mut bytes = vec![]; - match handle.read_to_end(&mut bytes) { - Ok(_) => {} - Err(_) => return Err(vm.new_value_error("Error reading from Buffer".to_string())), - } + let bytes = match read_byte { + None => { + let mut bytes = vec![]; + handle + .read_to_end(&mut bytes) + .map_err(|_| vm.new_value_error("Error reading from Buffer".to_string()))?; + bytes + } + Some(read_byte) => { + let mut bytes = vec![0; objint::get_value(&read_byte).to_usize().unwrap()]; + handle + .read_exact(&mut bytes) + .map_err(|_| vm.new_value_error("Error reading from Buffer".to_string()))?; + let updated = os::raw_file_number(handle); + vm.set_attr(file_io, "fileno", vm.ctx.new_int(updated))?; + bytes + } + }; Ok(vm.ctx.new_bytes(bytes)) }