From ab658a11a6688d649c3bcdafa159b293c6a6e4ce Mon Sep 17 00:00:00 2001 From: Noah <33094578+coolreader18@users.noreply.github.com> Date: Thu, 13 May 2021 18:02:01 -0500 Subject: [PATCH] Call the encoder for TextIOWrapper.write --- Lib/test/test_io.py | 12 +----------- vm/src/stdlib/io.rs | 41 +++++++++++++++++++++++++++++++---------- 2 files changed, 32 insertions(+), 21 deletions(-) diff --git a/Lib/test/test_io.py b/Lib/test/test_io.py index a58ce3be4..430050e56 100644 --- a/Lib/test/test_io.py +++ b/Lib/test/test_io.py @@ -3492,8 +3492,6 @@ class TextIOWrapperTest(unittest.TestCase): t = self.TextIOWrapper(self.StringIO('a')) self.assertRaises(TypeError, t.read) - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_illegal_encoder(self): # Issue 31271: Calling write() while the return value of encoder's # encode() is invalid shouldn't cause an assertion failure. @@ -3779,11 +3777,6 @@ class CTextIOWrapperTest(TextIOWrapperTest): def test_repr(self): super().test_repr() - # TODO: RUSTPYTHON - @unittest.expectedFailure - def test_encoding_errors_writing(self): - super().test_encoding_errors_writing() - # TODO: RUSTPYTHON @unittest.expectedFailure def test_newlines(self): @@ -4493,6 +4486,7 @@ class SignalsTest(unittest.TestCase): def test_interrupted_write_buffered(self): self.check_interrupted_write(b"xy", b"xy", mode="wb") + @unittest.skip("TODO: RUSTPYTHON, hangs?") def test_interrupted_write_text(self): self.check_interrupted_write("xy", b"xy", mode="w", encoding="ascii") @@ -4654,10 +4648,6 @@ class PySignalsTest(SignalsTest): test_reentrant_write_buffered = None test_reentrant_write_text = None - @unittest.skip("TODO: RUSTPYTHON, hangs?") - def test_interrupted_write_text(self): - super().test_interrupted_write_text() - def load_tests(*args): tests = (CIOTest, PyIOTest, APIMismatchTest, diff --git a/vm/src/stdlib/io.rs b/vm/src/stdlib/io.rs index 868ebaba7..9449486e6 100644 --- a/vm/src/stdlib/io.rs +++ b/vm/src/stdlib/io.rs @@ -1930,7 +1930,11 @@ mod _io { } } - type EncodeFunc = fn(); + // TODO: implement legit fast-paths for other encodings + type EncodeFunc = fn(PyStrRef) -> PendingWrite; + fn textio_encode_utf8(s: PyStrRef) -> PendingWrite { + PendingWrite::Utf8(s) + } #[derive(Debug)] struct TextIOData { @@ -1974,16 +1978,14 @@ mod _io { #[derive(Debug)] enum PendingWrite { - Str(PyStrRef), - // TODO: encode() str's when encoding != utf8 - #[allow(unused)] + Utf8(PyStrRef), Bytes(PyBytesRef), } impl PendingWrite { fn as_bytes(&self) -> &[u8] { match self { - Self::Str(s) => s.borrow_value().as_bytes(), + Self::Utf8(s) => s.borrow_value().as_bytes(), Self::Bytes(b) => b.borrow_value(), } } @@ -2163,7 +2165,7 @@ mod _io { let encodefunc = encoding_name.and_then(|name| { name.payload::() .and_then(|name| match name.borrow_value() { - "utf-8" => Some((|| {}) as fn()), + "utf-8" => Some(textio_encode_utf8 as EncodeFunc), _ => None, }) }); @@ -2551,9 +2553,10 @@ mod _io { let mut textio = self.lock(vm)?; textio.check_closed(vm)?; - if textio.encoder.is_none() { - return Err(new_unsupported_operation(vm, "not writable".to_owned())); - } + let (encoder, encodefunc) = textio + .encoder + .as_ref() + .ok_or_else(|| new_unsupported_operation(vm, "not writable".to_owned()))?; let char_len = obj.char_len(); @@ -2580,7 +2583,25 @@ mod _io { } else { obj }; - let chunk = PendingWrite::Str(chunk); + let chunk = if let Some(encodefunc) = *encodefunc { + encodefunc(chunk) + } else { + let b = vm.call_method(encoder, "encode", (chunk.clone(),))?; + b.downcast::() + .map(PendingWrite::Bytes) + .or_else(|obj| { + // TODO: not sure if encode() returning the str it was passed is officially + // supported or just a quirk of how the CPython code is written + if obj.is(&chunk) { + Ok(PendingWrite::Utf8(chunk)) + } else { + Err(vm.new_type_error(format!( + "encoder should return a bytes object, not '{}'", + obj.class().name + ))) + } + })? + }; if textio.pending.num_bytes + chunk.as_bytes().len() > textio.chunk_size { textio.write_pending(vm)?; }