diff --git a/Lib/test/test_codecs.py b/Lib/test/test_codecs.py index df04653c6..a12e5893d 100644 --- a/Lib/test/test_codecs.py +++ b/Lib/test/test_codecs.py @@ -869,6 +869,11 @@ class UTF16Test(ReadTest, unittest.TestCase): with reader: self.assertEqual(reader.read(), s1) + # TODO: RUSTPYTHON + @unittest.expectedFailure + def test_incremental_surrogatepass(self): + super().test_incremental_surrogatepass() + class UTF16LETest(ReadTest, unittest.TestCase): encoding = "utf-16-le" ill_formed_sequence = b"\x80\xdc" @@ -917,6 +922,11 @@ class UTF16LETest(ReadTest, unittest.TestCase): self.assertEqual(b'\x00\xd8\x03\xde'.decode(self.encoding), "\U00010203") + # TODO: RUSTPYTHON + @unittest.expectedFailure + def test_incremental_surrogatepass(self): + super().test_incremental_surrogatepass() + class UTF16BETest(ReadTest, unittest.TestCase): encoding = "utf-16-be" ill_formed_sequence = b"\xdc\x80" @@ -965,6 +975,11 @@ class UTF16BETest(ReadTest, unittest.TestCase): self.assertEqual(b'\xd8\x00\xde\x03'.decode(self.encoding), "\U00010203") + # TODO: RUSTPYTHON + @unittest.expectedFailure + def test_incremental_surrogatepass(self): + super().test_incremental_surrogatepass() + class UTF8Test(ReadTest, unittest.TestCase): encoding = "utf-8" ill_formed_sequence = b"\xed\xb2\x80" @@ -998,8 +1013,6 @@ class UTF8Test(ReadTest, unittest.TestCase): self.check_state_handling_decode(self.encoding, u, u.encode(self.encoding)) - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_decode_error(self): for data, error_handler, expected in ( (b'[\x80\xff]', 'ignore', '[]'), @@ -1026,8 +1039,6 @@ class UTF8Test(ReadTest, unittest.TestCase): exc = cm.exception self.assertEqual(exc.object[exc.start:exc.end], '\uD800\uDFFF') - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_surrogatepass_handler(self): self.assertEqual("abc\ud800def".encode(self.encoding, "surrogatepass"), self.BOM + b"abc\xed\xa0\x80def") @@ -2884,8 +2895,6 @@ class EscapeEncodeTest(unittest.TestCase): class SurrogateEscapeTest(unittest.TestCase): - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_utf8(self): # Bad byte self.assertEqual(b"foo\x80bar".decode("utf-8", "surrogateescape"), @@ -2898,8 +2907,6 @@ class SurrogateEscapeTest(unittest.TestCase): self.assertEqual("\udced\udcb0\udc80".encode("utf-8", "surrogateescape"), b"\xed\xb0\x80") - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_ascii(self): # bad byte self.assertEqual(b"foo\x80bar".decode("ascii", "surrogateescape"), @@ -2916,8 +2923,6 @@ class SurrogateEscapeTest(unittest.TestCase): self.assertEqual("foo\udca5bar".encode("iso-8859-3", "surrogateescape"), b"foo\xa5bar") - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_latin1(self): # Issue6373 self.assertEqual("\udce4\udceb\udcef\udcf6\udcfc".encode("latin-1", "surrogateescape"), @@ -3561,8 +3566,6 @@ class ASCIITest(unittest.TestCase): def test_encode(self): self.assertEqual('abc123'.encode('ascii'), b'abc123') - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_encode_error(self): for data, error_handler, expected in ( ('[\x80\xff\u20ac]', 'ignore', b'[]'), @@ -3585,8 +3588,6 @@ class ASCIITest(unittest.TestCase): def test_decode(self): self.assertEqual(b'abc'.decode('ascii'), 'abc') - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_decode_error(self): for data, error_handler, expected in ( (b'[\x80\xff]', 'ignore', '[]'), @@ -3609,8 +3610,6 @@ class Latin1Test(unittest.TestCase): with self.subTest(data=data, expected=expected): self.assertEqual(data.encode('latin1'), expected) - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_encode_errors(self): for data, error_handler, expected in ( ('[\u20ac\udc80]', 'ignore', b'[]'), diff --git a/Lib/test/test_json/test_scanstring.py b/Lib/test/test_json/test_scanstring.py index 682dc7499..af4bb3a63 100644 --- a/Lib/test/test_json/test_scanstring.py +++ b/Lib/test/test_json/test_scanstring.py @@ -86,8 +86,6 @@ class TestScanstring: scanstring('["Bad value", truth]', 2, True), ('Bad value', 12)) - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_surrogates(self): scanstring = self.json.decoder.scanstring def assertScan(given, expect): diff --git a/Lib/test/test_regrtest.py b/Lib/test/test_regrtest.py index aab2d9f7a..fc56ec4af 100644 --- a/Lib/test/test_regrtest.py +++ b/Lib/test/test_regrtest.py @@ -945,7 +945,6 @@ class ArgsTestCase(BaseTestCase): """) self.check_leak(code, 'file descriptors') - @unittest.expectedFailureIfWindows('TODO: RUSTPYTHON Windows') def test_list_tests(self): # test --list-tests tests = [self.create_test() for i in range(5)] @@ -953,7 +952,6 @@ class ArgsTestCase(BaseTestCase): self.assertEqual(output.rstrip().splitlines(), tests) - @unittest.expectedFailureIfWindows('TODO: RUSTPYTHON Windows') def test_list_cases(self): # test --list-cases code = textwrap.dedent(""" diff --git a/Lib/test/test_stringprep.py b/Lib/test/test_stringprep.py index 118f3f086..d4b4a13d0 100644 --- a/Lib/test/test_stringprep.py +++ b/Lib/test/test_stringprep.py @@ -6,8 +6,6 @@ import unittest from stringprep import * class StringprepTests(unittest.TestCase): - # TODO: RUSTPYTHON - @unittest.expectedFailure def test(self): self.assertTrue(in_table_a1("\u0221")) self.assertFalse(in_table_a1("\u0222")) diff --git a/Lib/test/test_subprocess.py b/Lib/test/test_subprocess.py index d7507eb7f..e5b18fe20 100644 --- a/Lib/test/test_subprocess.py +++ b/Lib/test/test_subprocess.py @@ -1198,8 +1198,6 @@ class ProcessTestCase(BaseTestCase): stdout, stderr = popen.communicate(input='') self.assertEqual(stdout, '1\n2\n3\n4') - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_communicate_errors(self): for errors, expected in [ ('ignore', ''), diff --git a/Lib/test/test_tarfile.py b/Lib/test/test_tarfile.py index 4ae81cb99..63f7b347a 100644 --- a/Lib/test/test_tarfile.py +++ b/Lib/test/test_tarfile.py @@ -2086,11 +2086,6 @@ class UstarUnicodeTest(UnicodeTest, unittest.TestCase): format = tarfile.USTAR_FORMAT - # TODO: RUSTPYTHON - @unittest.expectedFailure - def test_uname_unicode(self): - super().test_uname_unicode() - # Test whether the utf-8 encoded version of a filename exceeds the 100 # bytes name field limit (every occurrence of '\xff' will be expanded to 2 # bytes). @@ -2170,13 +2165,6 @@ class GNUUnicodeTest(UnicodeTest, unittest.TestCase): format = tarfile.GNU_FORMAT - # TODO: RUSTPYTHON - @unittest.expectedFailure - def test_uname_unicode(self): - super().test_uname_unicode() - - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_bad_pax_header(self): # Test for issue #8633. GNU tar <= 1.23 creates raw binary fields # without a hdrcharset=BINARY header. @@ -2198,8 +2186,6 @@ class PAXUnicodeTest(UnicodeTest, unittest.TestCase): # PAX_FORMAT ignores encoding in write mode. test_unicode_filename_error = None - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_binary_header(self): # Test a POSIX.1-2008 compatible header with a hdrcharset=BINARY field. for encoding, name in ( diff --git a/Lib/test/test_unicode.py b/Lib/test/test_unicode.py index 5c2c6c29b..4da63c54d 100644 --- a/Lib/test/test_unicode.py +++ b/Lib/test/test_unicode.py @@ -608,8 +608,6 @@ class UnicodeTest(string_tests.CommonTest, self.assertEqual('abc' == bytearray(b'abc'), False) self.assertEqual('abc' != bytearray(b'abc'), True) - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_comparison(self): # Comparisons: self.assertEqual('abc', 'abc') @@ -830,8 +828,6 @@ class UnicodeTest(string_tests.CommonTest, warnings.simplefilter('ignore', DeprecationWarning) self.assertTrue(_testcapi.unicode_legacy_string(u).isidentifier()) - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_isprintable(self): self.assertTrue("".isprintable()) self.assertTrue(" ".isprintable()) @@ -847,8 +843,6 @@ class UnicodeTest(string_tests.CommonTest, self.assertTrue('\U0001F46F'.isprintable()) self.assertFalse('\U000E0020'.isprintable()) - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_surrogates(self): for s in ('a\uD800b\uDFFF', 'a\uDFFFb\uD800', 'a\uD800b\uDFFFa', 'a\uDFFFb\uD800a'): @@ -1827,8 +1821,6 @@ class UnicodeTest(string_tests.CommonTest, 'ill-formed sequence'): b'+@'.decode('utf-7') - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_codecs_utf8(self): self.assertEqual(''.encode('utf-8'), b'') self.assertEqual('\u20ac'.encode('utf-8'), b'\xe2\x82\xac') diff --git a/Lib/test/test_userstring.py b/Lib/test/test_userstring.py index c0017794e..51b4f6041 100644 --- a/Lib/test/test_userstring.py +++ b/Lib/test/test_userstring.py @@ -53,8 +53,6 @@ class UserStringTest( str3 = ustr3('TEST') self.assertEqual(fmt2 % str3, 'value is TEST') - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_encode_default_args(self): self.checkequal(b'hello', 'hello', 'encode') # Check that encoding defaults to utf-8 @@ -62,8 +60,6 @@ class UserStringTest( # Check that errors defaults to 'strict' self.checkraises(UnicodeError, '\ud800', 'encode') - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_encode_explicit_none_args(self): self.checkequal(b'hello', 'hello', 'encode', None, None) # Check that encoding defaults to utf-8 diff --git a/Lib/test/test_zipimport.py b/Lib/test/test_zipimport.py index b291d5301..488a67e80 100644 --- a/Lib/test/test_zipimport.py +++ b/Lib/test/test_zipimport.py @@ -730,6 +730,7 @@ class UncompressedZipImportTestCase(ImportHooksBaseTestCase): @unittest.skipIf(os_helper.TESTFN_UNENCODABLE is None, "need an unencodable filename") + @unittest.expectedFailureIfWindows("TODO: RUSTPYTHON") def testUnencodable(self): filename = os_helper.TESTFN_UNENCODABLE + ".zip" self.addCleanup(os_helper.unlink, filename) diff --git a/common/src/wtf8/mod.rs b/common/src/wtf8/mod.rs index 21c5de28b..57a1d6d7d 100644 --- a/common/src/wtf8/mod.rs +++ b/common/src/wtf8/mod.rs @@ -122,18 +122,18 @@ impl CodePoint { /// Returns the numeric value of the code point if it is a leading surrogate. #[inline] - pub fn to_lead_surrogate(self) -> Option { + pub fn to_lead_surrogate(self) -> Option { match self.value { - lead @ 0xD800..=0xDBFF => Some(lead as u16), + lead @ 0xD800..=0xDBFF => Some(LeadSurrogate(lead as u16)), _ => None, } } /// Returns the numeric value of the code point if it is a trailing surrogate. #[inline] - pub fn to_trail_surrogate(self) -> Option { + pub fn to_trail_surrogate(self) -> Option { match self.value { - trail @ 0xDC00..=0xDFFF => Some(trail as u16), + trail @ 0xDC00..=0xDFFF => Some(TrailSurrogate(trail as u16)), _ => None, } } @@ -216,6 +216,18 @@ impl PartialEq for char { } } +#[derive(Clone, Copy)] +pub struct LeadSurrogate(u16); + +#[derive(Clone, Copy)] +pub struct TrailSurrogate(u16); + +impl LeadSurrogate { + pub fn merge(self, trail: TrailSurrogate) -> char { + decode_surrogate_pair(self.0, trail.0) + } +} + /// An owned, growable string of well-formed WTF-8 data. /// /// Similar to `String`, but can additionally contain surrogate code points @@ -291,6 +303,14 @@ impl Wtf8Buf { Wtf8Buf { bytes: value } } + /// Create a WTF-8 string from a WTF-8 byte vec. + pub fn from_bytes(value: Vec) -> Result> { + match Wtf8::from_bytes(&value) { + Some(_) => Ok(unsafe { Self::from_bytes_unchecked(value) }), + None => Err(value), + } + } + /// Creates a WTF-8 string from a UTF-8 `String`. /// /// This takes ownership of the `String` and does not copy. @@ -750,15 +770,10 @@ impl Wtf8 { } fn decode_surrogate(b: &[u8]) -> Option { - let [a, b, c, ..] = *b else { return None }; - if (a & 0xf0) == 0xe0 && (b & 0xc0) == 0x80 && (c & 0xc0) == 0x80 { - // it's a three-byte code - let c = ((a as u32 & 0x0f) << 12) + ((b as u32 & 0x3f) << 6) + (c as u32 & 0x3f); - let 0xD800..=0xDFFF = c else { return None }; - Some(CodePoint { value: c }) - } else { - None - } + let [0xed, b2 @ (0xa0..), b3, ..] = *b else { + return None; + }; + Some(decode_surrogate(b2, b3).into()) } /// Returns the length, in WTF-8 bytes. @@ -914,14 +929,6 @@ impl Wtf8 { } } - #[inline] - fn final_lead_surrogate(&self) -> Option { - match self.bytes { - [.., 0xED, b2 @ 0xA0..=0xAF, b3] => Some(decode_surrogate(b2, b3)), - _ => None, - } - } - pub fn is_code_point_boundary(&self, index: usize) -> bool { is_code_point_boundary(self, index) } @@ -1222,6 +1229,12 @@ fn decode_surrogate(second_byte: u8, third_byte: u8) -> u16 { 0xD800 | (second_byte as u16 & 0x3F) << 6 | third_byte as u16 & 0x3F } +#[inline] +fn decode_surrogate_pair(lead: u16, trail: u16) -> char { + let code_point = 0x10000 + ((((lead - 0xD800) as u32) << 10) | (trail - 0xDC00) as u32); + unsafe { char::from_u32_unchecked(code_point) } +} + /// Copied from str::is_char_boundary #[inline] fn is_code_point_boundary(slice: &Wtf8, index: usize) -> bool { diff --git a/stdlib/src/json.rs b/stdlib/src/json.rs index aaac0b8be..f970ef5dc 100644 --- a/stdlib/src/json.rs +++ b/stdlib/src/json.rs @@ -13,6 +13,7 @@ mod _json { types::{Callable, Constructor}, }; use malachite_bigint::BigInt; + use rustpython_common::wtf8::Wtf8Buf; use std::str::FromStr; #[pyattr(name = "make_scanner")] @@ -253,8 +254,8 @@ mod _json { end: usize, strict: OptionalArg, vm: &VirtualMachine, - ) -> PyResult<(String, usize)> { - machinery::scanstring(s.as_str(), end, strict.unwrap_or(true)) + ) -> PyResult<(Wtf8Buf, usize)> { + machinery::scanstring(s.as_wtf8(), end, strict.unwrap_or(true)) .map_err(|e| py_decode_error(e, s, vm)) } } diff --git a/stdlib/src/json/machinery.rs b/stdlib/src/json/machinery.rs index 0614314f4..4612b5263 100644 --- a/stdlib/src/json/machinery.rs +++ b/stdlib/src/json/machinery.rs @@ -28,6 +28,9 @@ use std::io; +use itertools::Itertools; +use rustpython_common::wtf8::{CodePoint, Wtf8, Wtf8Buf}; + static ESCAPE_CHARS: [&str; 0x20] = [ "\\u0000", "\\u0001", "\\u0002", "\\u0003", "\\u0004", "\\u0005", "\\u0006", "\\u0007", "\\b", "\\t", "\\n", "\\u000", "\\f", "\\r", "\\u000e", "\\u000f", "\\u0010", "\\u0011", "\\u0012", @@ -111,22 +114,22 @@ impl DecodeError { } enum StrOrChar<'a> { - Str(&'a str), - Char(char), + Str(&'a Wtf8), + Char(CodePoint), } impl StrOrChar<'_> { fn len(&self) -> usize { match self { StrOrChar::Str(s) => s.len(), - StrOrChar::Char(c) => c.len_utf8(), + StrOrChar::Char(c) => c.len_wtf8(), } } } pub fn scanstring<'a>( - s: &'a str, + s: &'a Wtf8, end: usize, strict: bool, -) -> Result<(String, usize), DecodeError> { +) -> Result<(Wtf8Buf, usize), DecodeError> { let mut chunks: Vec> = Vec::new(); let mut output_len = 0usize; let mut push_chunk = |chunk: StrOrChar<'a>| { @@ -134,16 +137,16 @@ pub fn scanstring<'a>( chunks.push(chunk); }; let unterminated_err = || DecodeError::new("Unterminated string starting at", end - 1); - let mut chars = s.char_indices().enumerate().skip(end).peekable(); + let mut chars = s.code_point_indices().enumerate().skip(end).peekable(); let &(_, (mut chunk_start, _)) = chars.peek().ok_or_else(unterminated_err)?; while let Some((char_i, (byte_i, c))) = chars.next() { - match c { + match c.to_char_lossy() { '"' => { push_chunk(StrOrChar::Str(&s[chunk_start..byte_i])); - let mut out = String::with_capacity(output_len); + let mut out = Wtf8Buf::with_capacity(output_len); for x in chunks { match x { - StrOrChar::Str(s) => out.push_str(s), + StrOrChar::Str(s) => out.push_wtf8(s), StrOrChar::Char(c) => out.push(c), } } @@ -152,7 +155,7 @@ pub fn scanstring<'a>( '\\' => { push_chunk(StrOrChar::Str(&s[chunk_start..byte_i])); let (_, (_, c)) = chars.next().ok_or_else(unterminated_err)?; - let esc = match c { + let esc = match c.to_char_lossy() { '"' => "\"", '\\' => "\\", '/' => "/", @@ -162,41 +165,33 @@ pub fn scanstring<'a>( 'r' => "\r", 't' => "\t", 'u' => { - let surrogate_err = || DecodeError::new("unpaired surrogate", char_i); let mut uni = decode_unicode(&mut chars, char_i)?; chunk_start = byte_i + 6; - if (0xd800..=0xdbff).contains(&uni) { + if let Some(lead) = uni.to_lead_surrogate() { // uni is a surrogate -- try to find its pair - if let Some(&(pos2, (_, '\\'))) = chars.peek() { - // ok, the next char starts an escape - chars.next(); - if let Some((_, (_, 'u'))) = chars.peek() { - // ok, it's a unicode escape - chars.next(); - let uni2 = decode_unicode(&mut chars, pos2)?; + let mut chars2 = chars.clone(); + if let Some(((pos2, _), (_, _))) = chars2 + .next_tuple() + .filter(|((_, (_, c1)), (_, (_, c2)))| *c1 == '\\' && *c2 == 'u') + { + let uni2 = decode_unicode(&mut chars2, pos2)?; + if let Some(trail) = uni2.to_trail_surrogate() { + // ok, we found what we were looking for -- \uXXXX\uXXXX, both surrogates + uni = lead.merge(trail).into(); chunk_start = pos2 + 6; - if (0xdc00..=0xdfff).contains(&uni2) { - // ok, we found what we were looking for -- \uXXXX\uXXXX, both surrogates - uni = 0x10000 + (((uni - 0xd800) << 10) | (uni2 - 0xdc00)); - } else { - // if we don't find a matching surrogate, error -- until str - // isn't utf8 internally, we can't parse surrogates - return Err(surrogate_err()); - } - } else { - return Err(surrogate_err()); + chars = chars2; } } } - push_chunk(StrOrChar::Char( - std::char::from_u32(uni).ok_or_else(surrogate_err)?, - )); + push_chunk(StrOrChar::Char(uni)); continue; } - _ => return Err(DecodeError::new(format!("Invalid \\escape: {c:?}"), char_i)), + _ => { + return Err(DecodeError::new(format!("Invalid \\escape: {c:?}"), char_i)); + } }; chunk_start = byte_i + 2; - push_chunk(StrOrChar::Str(esc)); + push_chunk(StrOrChar::Str(esc.as_ref())); } '\x00'..='\x1f' if strict => { return Err(DecodeError::new( @@ -211,16 +206,16 @@ pub fn scanstring<'a>( } #[inline] -fn decode_unicode(it: &mut I, pos: usize) -> Result +fn decode_unicode(it: &mut I, pos: usize) -> Result where - I: Iterator, + I: Iterator, { let err = || DecodeError::new("Invalid \\uXXXX escape", pos); let mut uni = 0; for x in (0..4).rev() { let (_, (_, c)) = it.next().ok_or_else(err)?; - let d = c.to_digit(16).ok_or_else(err)?; - uni += d * 16u32.pow(x); + let d = c.to_char().and_then(|c| c.to_digit(16)).ok_or_else(err)? as u16; + uni += d * 16u16.pow(x); } - Ok(uni) + Ok(uni.into()) } diff --git a/vm/src/builtins/complex.rs b/vm/src/builtins/complex.rs index e665d1e27..01dd65f51 100644 --- a/vm/src/builtins/complex.rs +++ b/vm/src/builtins/complex.rs @@ -179,9 +179,12 @@ impl Constructor for PyComplex { "complex() can't take second arg if first is a string".to_owned(), )); } - let value = parse_str(s.as_str().trim()).ok_or_else(|| { - vm.new_value_error("complex() arg is a malformed string".to_owned()) - })?; + let value = s + .to_str() + .and_then(|s| parse_str(s.trim())) + .ok_or_else(|| { + vm.new_value_error("complex() arg is a malformed string".to_owned()) + })?; return Self::from(value) .into_ref_with_type(vm, cls) .map(Into::into); diff --git a/vm/src/builtins/float.rs b/vm/src/builtins/float.rs index b4601fbb9..48ccd2c43 100644 --- a/vm/src/builtins/float.rs +++ b/vm/src/builtins/float.rs @@ -161,7 +161,7 @@ impl Constructor for PyFloat { fn float_from_string(val: PyObjectRef, vm: &VirtualMachine) -> PyResult { let (bytearray, buffer, buffer_lock); let b = if let Some(s) = val.payload_if_subclass::(vm) { - s.as_str().trim().as_bytes() + s.as_wtf8().trim().as_bytes() } else if let Some(bytes) = val.payload_if_subclass::(vm) { bytes.as_bytes() } else if let Some(buf) = val.payload_if_subclass::(vm) { diff --git a/vm/src/builtins/int.rs b/vm/src/builtins/int.rs index aa9613e9d..f457bf5ed 100644 --- a/vm/src/builtins/int.rs +++ b/vm/src/builtins/int.rs @@ -847,7 +847,7 @@ fn try_int_radix(obj: &PyObject, base: u32, vm: &VirtualMachine) -> PyResult { - let s = string.as_str().trim(); + let s = string.as_wtf8().trim(); bytes_to_int(s.as_bytes(), base) } bytes @ PyBytes => { diff --git a/vm/src/builtins/str.rs b/vm/src/builtins/str.rs index 8fe390494..dfb9de9ba 100644 --- a/vm/src/builtins/str.rs +++ b/vm/src/builtins/str.rs @@ -424,6 +424,23 @@ impl PyStr { self.data.as_str() } + pub fn try_to_str(&self, vm: &VirtualMachine) -> PyResult<&str> { + self.to_str().ok_or_else(|| { + let start = self + .as_wtf8() + .code_points() + .position(|c| c.to_char().is_none()) + .unwrap(); + vm.new_unicode_encode_error_real( + identifier!(vm, utf_8).to_owned(), + vm.ctx.new_str(self.data.clone()), + start, + start + 1, + vm.ctx.new_str("surrogates not allowed"), + ) + }) + } + pub fn to_string_lossy(&self) -> Cow<'_, str> { self.to_str() .map(Cow::Borrowed) @@ -850,9 +867,9 @@ impl PyStr { /// If the string starts with the prefix string, return string[len(prefix):] /// Otherwise, return a copy of the original string. #[pymethod] - fn removeprefix(&self, pref: PyStrRef) -> String { - self.as_str() - .py_removeprefix(pref.as_str(), pref.byte_len(), |s, p| s.starts_with(p)) + fn removeprefix(&self, pref: PyStrRef) -> Wtf8Buf { + self.as_wtf8() + .py_removeprefix(pref.as_wtf8(), pref.byte_len(), |s, p| s.starts_with(p)) .to_owned() } @@ -861,9 +878,9 @@ impl PyStr { /// If the string ends with the suffix string, return string[:len(suffix)] /// Otherwise, return a copy of the original string. #[pymethod] - fn removesuffix(&self, suffix: PyStrRef) -> String { - self.as_str() - .py_removesuffix(suffix.as_str(), suffix.byte_len(), |s, p| s.ends_with(p)) + fn removesuffix(&self, suffix: PyStrRef) -> Wtf8Buf { + self.as_wtf8() + .py_removesuffix(suffix.as_wtf8(), suffix.byte_len(), |s, p| s.ends_with(p)) .to_owned() } @@ -1294,7 +1311,8 @@ impl PyStr { #[pymethod] fn isidentifier(&self) -> bool { - let mut chars = self.as_str().chars(); + let Some(s) = self.to_str() else { return false }; + let mut chars = s.chars(); let is_identifier_start = chars.next().is_some_and(|c| c == '_' || is_xid_start(c)); // a string is not an identifier if it has whitespace or starts with a number is_identifier_start && chars.all(is_xid_continue) diff --git a/vm/src/builtins/type.rs b/vm/src/builtins/type.rs index 969d6db93..776c777cb 100644 --- a/vm/src/builtins/type.rs +++ b/vm/src/builtins/type.rs @@ -884,7 +884,7 @@ impl Constructor for PyType { attributes .entry(identifier!(vm, __qualname__)) - .or_insert_with(|| vm.ctx.new_str(name.as_str()).into()); + .or_insert_with(|| name.clone().into()); if attributes.get(identifier!(vm, __eq__)).is_some() && attributes.get(identifier!(vm, __hash__)).is_none() diff --git a/vm/src/protocol/number.rs b/vm/src/protocol/number.rs index 2b6720e84..dd039c273 100644 --- a/vm/src/protocol/number.rs +++ b/vm/src/protocol/number.rs @@ -77,7 +77,7 @@ impl PyObject { )) }) } else if let Some(s) = self.payload::() { - try_convert(self, s.as_str().trim().as_bytes(), vm) + try_convert(self, s.as_wtf8().trim().as_bytes(), vm) } else if let Some(bytes) = self.payload::() { try_convert(self, bytes, vm) } else if let Some(bytearray) = self.payload::() { diff --git a/vm/src/protocol/object.rs b/vm/src/protocol/object.rs index 4e69cf38a..4cdcb6825 100644 --- a/vm/src/protocol/object.rs +++ b/vm/src/protocol/object.rs @@ -224,7 +224,7 @@ impl PyObject { dict: Option, vm: &VirtualMachine, ) -> PyResult> { - let name = name_str.as_str(); + let name = name_str.as_wtf8(); let obj_cls = self.class(); let cls_attr_name = vm.ctx.interned_str(name_str); let cls_attr = match cls_attr_name.and_then(|name| obj_cls.get_attr(name)) { diff --git a/vm/src/stdlib/codecs.rs b/vm/src/stdlib/codecs.rs index 6ad2a74f4..320d83968 100644 --- a/vm/src/stdlib/codecs.rs +++ b/vm/src/stdlib/codecs.rs @@ -26,7 +26,7 @@ mod _codecs { fn lookup(encoding: PyStrRef, vm: &VirtualMachine) -> PyResult { vm.state .codec_registry - .lookup(encoding.as_str(), vm) + .lookup(encoding.try_to_str(vm)?, vm) .map(|codec| codec.into_tuple().into()) } diff --git a/vm/src/stdlib/io.rs b/vm/src/stdlib/io.rs index 77d923172..4cf3c058d 100644 --- a/vm/src/stdlib/io.rs +++ b/vm/src/stdlib/io.rs @@ -2245,7 +2245,7 @@ mod _io { let newline = args.newline.unwrap_or_default(); let (encoder, decoder) = - Self::find_coder(&buffer, encoding.as_str(), &errors, newline, vm)?; + Self::find_coder(&buffer, encoding.try_to_str(vm)?, &errors, newline, vm)?; *data = Some(TextIOData { buffer, @@ -2345,7 +2345,7 @@ mod _io { if let Some(encoding) = args.encoding { let (encoder, decoder) = Self::find_coder( &data.buffer, - encoding.as_str(), + encoding.try_to_str(vm)?, &data.errors, data.newline, vm, @@ -3468,9 +3468,9 @@ mod _io { // return the entire contents of the underlying #[pymethod] - fn getvalue(&self, vm: &VirtualMachine) -> PyResult { + fn getvalue(&self, vm: &VirtualMachine) -> PyResult { let bytes = self.buffer(vm)?.getvalue(); - String::from_utf8(bytes) + Wtf8Buf::from_bytes(bytes) .map_err(|_| vm.new_value_error("Error Retrieving Value".to_owned())) } @@ -3491,10 +3491,10 @@ mod _io { // If k is undefined || k == -1, then we read all bytes until the end of the file. // This also increments the stream position by the value of k #[pymethod] - fn read(&self, size: OptionalSize, vm: &VirtualMachine) -> PyResult { + fn read(&self, size: OptionalSize, vm: &VirtualMachine) -> PyResult { let data = self.buffer(vm)?.read(size.to_usize()).unwrap_or_default(); - let value = String::from_utf8(data) + let value = Wtf8Buf::from_bytes(data) .map_err(|_| vm.new_value_error("Error Retrieving Value".to_owned()))?; Ok(value) } @@ -3505,11 +3505,11 @@ mod _io { } #[pymethod] - fn readline(&self, size: OptionalSize, vm: &VirtualMachine) -> PyResult { + fn readline(&self, size: OptionalSize, vm: &VirtualMachine) -> PyResult { // TODO size should correspond to the number of characters, at the moments its the number of // bytes. let input = self.buffer(vm)?.readline(size.to_usize(), vm)?; - String::from_utf8(input) + Wtf8Buf::from_bytes(input) .map_err(|_| vm.new_value_error("Error Retrieving Value".to_owned())) } diff --git a/vm/src/stdlib/time.rs b/vm/src/stdlib/time.rs index c464dc3ab..f98530e84 100644 --- a/vm/src/stdlib/time.rs +++ b/vm/src/stdlib/time.rs @@ -327,8 +327,12 @@ mod decl { * raises an error if unsupported format is supplied. * If error happens, we set result as input arg. */ - write!(&mut formatted_time, "{}", instant.format(format.as_str())) - .unwrap_or_else(|_| formatted_time = format.to_string()); + write!( + &mut formatted_time, + "{}", + instant.format(format.try_to_str(vm)?) + ) + .unwrap_or_else(|_| formatted_time = format.to_string()); Ok(vm.ctx.new_str(formatted_time).into()) } diff --git a/vm/src/utils.rs b/vm/src/utils.rs index e2bc99368..78edfb71c 100644 --- a/vm/src/utils.rs +++ b/vm/src/utils.rs @@ -1,3 +1,5 @@ +use rustpython_common::wtf8::Wtf8; + use crate::{ PyObjectRef, PyResult, VirtualMachine, builtins::PyStr, @@ -18,9 +20,9 @@ impl ToPyObject for std::convert::Infallible { } } -pub trait ToCString: AsRef { +pub trait ToCString: AsRef { fn to_cstring(&self, vm: &VirtualMachine) -> PyResult { - std::ffi::CString::new(self.as_ref()).map_err(|err| err.to_pyexception(vm)) + std::ffi::CString::new(self.as_ref().as_bytes()).map_err(|err| err.to_pyexception(vm)) } fn ensure_no_nul(&self, vm: &VirtualMachine) -> PyResult<()> { if self.as_ref().as_bytes().contains(&b'\0') {