diff --git a/tests/snippets/stdlib_io_bytesio.py b/tests/snippets/stdlib_io_bytesio.py index f6f97e684..571444876 100644 --- a/tests/snippets/stdlib_io_bytesio.py +++ b/tests/snippets/stdlib_io_bytesio.py @@ -7,6 +7,7 @@ def test_01(): f = BytesIO() f.write(bytes_string) + assert f.tell() == len(bytes_string) assert f.getvalue() == bytes_string def test_02(): @@ -39,12 +40,41 @@ def test_04(): f = BytesIO(string) assert f.read(4) == b'Test' + assert f.tell() == 4 assert f.seek(0) == 0 assert f.read(4) == b'Test' +def test_05(): + """ + Tests that the write method accpets bytearray + """ + bytes_string = b'Test String 5' + + f = BytesIO() + f.write(bytearray(bytes_string)) + + assert f.getvalue() == bytes_string + + +def test_06(): + """ + Tests readline + """ + bytes_string = b'Test String 6\nnew line is here\nfinished' + + f = BytesIO(bytes_string) + + assert f.readline() == b'Test String 6\n' + assert f.readline() == b'new line is here\n' + assert f.readline() == b'finished' + assert f.readline() == b'' + + if __name__ == "__main__": test_01() test_02() test_03() test_04() + test_05() + test_06() diff --git a/tests/snippets/stdlib_io_stringio.py b/tests/snippets/stdlib_io_stringio.py index 0b4182527..828f0506e 100644 --- a/tests/snippets/stdlib_io_stringio.py +++ b/tests/snippets/stdlib_io_stringio.py @@ -10,6 +10,7 @@ def test_01(): f = StringIO() f.write(string) + assert f.tell() == len(string) assert f.getvalue() == string def test_02(): @@ -46,11 +47,26 @@ def test_04(): f = StringIO(string) assert f.read(4) == 'Test' + assert f.tell() == 4 assert f.seek(0) == 0 assert f.read(4) == 'Test' +def test_05(): + """ + Tests readline + """ + string = 'Test String 6\nnew line is here\nfinished' + + f = StringIO(string) + + assert f.readline() == 'Test String 6\n' + assert f.readline() == 'new line is here\n' + assert f.readline() == 'finished' + assert f.readline() == '' + if __name__ == "__main__": test_01() test_02() test_03() test_04() + test_05() diff --git a/vm/src/stdlib/io.rs b/vm/src/stdlib/io.rs index bee7e1d7f..87ab1dcd4 100644 --- a/vm/src/stdlib/io.rs +++ b/vm/src/stdlib/io.rs @@ -11,6 +11,7 @@ use num_traits::ToPrimitive; use super::os; use crate::function::{OptionalArg, OptionalOption, PyFuncArgs}; use crate::obj::objbytearray::PyByteArray; +use crate::obj::objbyteinner::PyBytesLike; use crate::obj::objbytes; use crate::obj::objbytes::PyBytes; use crate::obj::objint::{self, PyIntRef}; @@ -81,6 +82,19 @@ impl BufferedIO { Some(buffer) } + + fn tell(&self) -> u64 { + self.cursor.position() + } + + fn readline(&mut self) -> Option { + let mut buf = String::new(); + + match self.cursor.read_line(&mut buf) { + Ok(_) => Some(buf), + Err(_) => None, + } + } } #[derive(Debug)] @@ -141,6 +155,17 @@ impl PyStringIORef { Err(_) => Err(vm.new_value_error("Error Retrieving Value".to_string())), } } + + fn tell(self, _vm: &VirtualMachine) -> u64 { + self.buffer.borrow().tell() + } + + fn readline(self, vm: &VirtualMachine) -> PyResult { + match self.buffer.borrow_mut().readline() { + Some(line) => Ok(line), + None => Err(vm.new_value_error("Error Performing Operation".to_string())), + } + } } fn string_io_new( @@ -173,11 +198,11 @@ impl PyValue for PyBytesIO { } impl PyBytesIORef { - fn write(self, data: objbytes::PyBytesRef, vm: &VirtualMachine) -> PyResult { - let bytes = data.get_value(); + fn write(self, data: PyBytesLike, vm: &VirtualMachine) -> PyResult { + let bytes = data.to_cow(); - match self.buffer.borrow_mut().write(bytes) { - Some(value) => Ok(vm.ctx.new_int(value)), + match self.buffer.borrow_mut().write(&bytes) { + Some(value) => Ok(value), None => Err(vm.new_type_error("Error Writing Bytes".to_string())), } } @@ -207,6 +232,17 @@ impl PyBytesIORef { fn seekable(self, _vm: &VirtualMachine) -> bool { true } + + fn tell(self, _vm: &VirtualMachine) -> u64 { + self.buffer.borrow().tell() + } + + fn readline(self, vm: &VirtualMachine) -> PyResult> { + match self.buffer.borrow_mut().readline() { + Some(line) => Ok(line.as_bytes().to_vec()), + None => Err(vm.new_value_error("Error Performing Operation".to_string())), + } + } } fn bytes_io_new( @@ -720,7 +756,9 @@ pub fn make_module(vm: &VirtualMachine) -> PyObjectRef { "seekable" => ctx.new_rustfunc(PyStringIORef::seekable), "read" => ctx.new_rustfunc(PyStringIORef::read), "write" => ctx.new_rustfunc(PyStringIORef::write), - "getvalue" => ctx.new_rustfunc(PyStringIORef::getvalue) + "getvalue" => ctx.new_rustfunc(PyStringIORef::getvalue), + "tell" => ctx.new_rustfunc(PyStringIORef::tell), + "readline" => ctx.new_rustfunc(PyStringIORef::readline), }); //BytesIO: in-memory bytes @@ -731,7 +769,9 @@ pub fn make_module(vm: &VirtualMachine) -> PyObjectRef { "seek" => ctx.new_rustfunc(PyBytesIORef::seek), "seekable" => ctx.new_rustfunc(PyBytesIORef::seekable), "write" => ctx.new_rustfunc(PyBytesIORef::write), - "getvalue" => ctx.new_rustfunc(PyBytesIORef::getvalue) + "getvalue" => ctx.new_rustfunc(PyBytesIORef::getvalue), + "tell" => ctx.new_rustfunc(PyBytesIORef::tell), + "readline" => ctx.new_rustfunc(PyBytesIORef::readline), }); py_module!(vm, "_io", {