Fix more surrogate crashes

This commit is contained in:
Noa
2025-03-26 20:37:26 -05:00
parent c6cab4c43a
commit 0a07cd931f
23 changed files with 142 additions and 140 deletions

View File

@@ -869,6 +869,11 @@ class UTF16Test(ReadTest, unittest.TestCase):
with reader:
self.assertEqual(reader.read(), s1)
# TODO: RUSTPYTHON
@unittest.expectedFailure
def test_incremental_surrogatepass(self):
super().test_incremental_surrogatepass()
class UTF16LETest(ReadTest, unittest.TestCase):
encoding = "utf-16-le"
ill_formed_sequence = b"\x80\xdc"
@@ -917,6 +922,11 @@ class UTF16LETest(ReadTest, unittest.TestCase):
self.assertEqual(b'\x00\xd8\x03\xde'.decode(self.encoding),
"\U00010203")
# TODO: RUSTPYTHON
@unittest.expectedFailure
def test_incremental_surrogatepass(self):
super().test_incremental_surrogatepass()
class UTF16BETest(ReadTest, unittest.TestCase):
encoding = "utf-16-be"
ill_formed_sequence = b"\xdc\x80"
@@ -965,6 +975,11 @@ class UTF16BETest(ReadTest, unittest.TestCase):
self.assertEqual(b'\xd8\x00\xde\x03'.decode(self.encoding),
"\U00010203")
# TODO: RUSTPYTHON
@unittest.expectedFailure
def test_incremental_surrogatepass(self):
super().test_incremental_surrogatepass()
class UTF8Test(ReadTest, unittest.TestCase):
encoding = "utf-8"
ill_formed_sequence = b"\xed\xb2\x80"
@@ -998,8 +1013,6 @@ class UTF8Test(ReadTest, unittest.TestCase):
self.check_state_handling_decode(self.encoding,
u, u.encode(self.encoding))
# TODO: RUSTPYTHON
@unittest.expectedFailure
def test_decode_error(self):
for data, error_handler, expected in (
(b'[\x80\xff]', 'ignore', '[]'),
@@ -1026,8 +1039,6 @@ class UTF8Test(ReadTest, unittest.TestCase):
exc = cm.exception
self.assertEqual(exc.object[exc.start:exc.end], '\uD800\uDFFF')
# TODO: RUSTPYTHON
@unittest.expectedFailure
def test_surrogatepass_handler(self):
self.assertEqual("abc\ud800def".encode(self.encoding, "surrogatepass"),
self.BOM + b"abc\xed\xa0\x80def")
@@ -2884,8 +2895,6 @@ class EscapeEncodeTest(unittest.TestCase):
class SurrogateEscapeTest(unittest.TestCase):
# TODO: RUSTPYTHON
@unittest.expectedFailure
def test_utf8(self):
# Bad byte
self.assertEqual(b"foo\x80bar".decode("utf-8", "surrogateescape"),
@@ -2898,8 +2907,6 @@ class SurrogateEscapeTest(unittest.TestCase):
self.assertEqual("\udced\udcb0\udc80".encode("utf-8", "surrogateescape"),
b"\xed\xb0\x80")
# TODO: RUSTPYTHON
@unittest.expectedFailure
def test_ascii(self):
# bad byte
self.assertEqual(b"foo\x80bar".decode("ascii", "surrogateescape"),
@@ -2916,8 +2923,6 @@ class SurrogateEscapeTest(unittest.TestCase):
self.assertEqual("foo\udca5bar".encode("iso-8859-3", "surrogateescape"),
b"foo\xa5bar")
# TODO: RUSTPYTHON
@unittest.expectedFailure
def test_latin1(self):
# Issue6373
self.assertEqual("\udce4\udceb\udcef\udcf6\udcfc".encode("latin-1", "surrogateescape"),
@@ -3561,8 +3566,6 @@ class ASCIITest(unittest.TestCase):
def test_encode(self):
self.assertEqual('abc123'.encode('ascii'), b'abc123')
# TODO: RUSTPYTHON
@unittest.expectedFailure
def test_encode_error(self):
for data, error_handler, expected in (
('[\x80\xff\u20ac]', 'ignore', b'[]'),
@@ -3585,8 +3588,6 @@ class ASCIITest(unittest.TestCase):
def test_decode(self):
self.assertEqual(b'abc'.decode('ascii'), 'abc')
# TODO: RUSTPYTHON
@unittest.expectedFailure
def test_decode_error(self):
for data, error_handler, expected in (
(b'[\x80\xff]', 'ignore', '[]'),
@@ -3609,8 +3610,6 @@ class Latin1Test(unittest.TestCase):
with self.subTest(data=data, expected=expected):
self.assertEqual(data.encode('latin1'), expected)
# TODO: RUSTPYTHON
@unittest.expectedFailure
def test_encode_errors(self):
for data, error_handler, expected in (
('[\u20ac\udc80]', 'ignore', b'[]'),

View File

@@ -86,8 +86,6 @@ class TestScanstring:
scanstring('["Bad value", truth]', 2, True),
('Bad value', 12))
# TODO: RUSTPYTHON
@unittest.expectedFailure
def test_surrogates(self):
scanstring = self.json.decoder.scanstring
def assertScan(given, expect):

View File

@@ -945,7 +945,6 @@ class ArgsTestCase(BaseTestCase):
""")
self.check_leak(code, 'file descriptors')
@unittest.expectedFailureIfWindows('TODO: RUSTPYTHON Windows')
def test_list_tests(self):
# test --list-tests
tests = [self.create_test() for i in range(5)]
@@ -953,7 +952,6 @@ class ArgsTestCase(BaseTestCase):
self.assertEqual(output.rstrip().splitlines(),
tests)
@unittest.expectedFailureIfWindows('TODO: RUSTPYTHON Windows')
def test_list_cases(self):
# test --list-cases
code = textwrap.dedent("""

View File

@@ -6,8 +6,6 @@ import unittest
from stringprep import *
class StringprepTests(unittest.TestCase):
# TODO: RUSTPYTHON
@unittest.expectedFailure
def test(self):
self.assertTrue(in_table_a1("\u0221"))
self.assertFalse(in_table_a1("\u0222"))

View File

@@ -1198,8 +1198,6 @@ class ProcessTestCase(BaseTestCase):
stdout, stderr = popen.communicate(input='')
self.assertEqual(stdout, '1\n2\n3\n4')
# TODO: RUSTPYTHON
@unittest.expectedFailure
def test_communicate_errors(self):
for errors, expected in [
('ignore', ''),

View File

@@ -2086,11 +2086,6 @@ class UstarUnicodeTest(UnicodeTest, unittest.TestCase):
format = tarfile.USTAR_FORMAT
# TODO: RUSTPYTHON
@unittest.expectedFailure
def test_uname_unicode(self):
super().test_uname_unicode()
# Test whether the utf-8 encoded version of a filename exceeds the 100
# bytes name field limit (every occurrence of '\xff' will be expanded to 2
# bytes).
@@ -2170,13 +2165,6 @@ class GNUUnicodeTest(UnicodeTest, unittest.TestCase):
format = tarfile.GNU_FORMAT
# TODO: RUSTPYTHON
@unittest.expectedFailure
def test_uname_unicode(self):
super().test_uname_unicode()
# TODO: RUSTPYTHON
@unittest.expectedFailure
def test_bad_pax_header(self):
# Test for issue #8633. GNU tar <= 1.23 creates raw binary fields
# without a hdrcharset=BINARY header.
@@ -2198,8 +2186,6 @@ class PAXUnicodeTest(UnicodeTest, unittest.TestCase):
# PAX_FORMAT ignores encoding in write mode.
test_unicode_filename_error = None
# TODO: RUSTPYTHON
@unittest.expectedFailure
def test_binary_header(self):
# Test a POSIX.1-2008 compatible header with a hdrcharset=BINARY field.
for encoding, name in (

View File

@@ -608,8 +608,6 @@ class UnicodeTest(string_tests.CommonTest,
self.assertEqual('abc' == bytearray(b'abc'), False)
self.assertEqual('abc' != bytearray(b'abc'), True)
# TODO: RUSTPYTHON
@unittest.expectedFailure
def test_comparison(self):
# Comparisons:
self.assertEqual('abc', 'abc')
@@ -830,8 +828,6 @@ class UnicodeTest(string_tests.CommonTest,
warnings.simplefilter('ignore', DeprecationWarning)
self.assertTrue(_testcapi.unicode_legacy_string(u).isidentifier())
# TODO: RUSTPYTHON
@unittest.expectedFailure
def test_isprintable(self):
self.assertTrue("".isprintable())
self.assertTrue(" ".isprintable())
@@ -847,8 +843,6 @@ class UnicodeTest(string_tests.CommonTest,
self.assertTrue('\U0001F46F'.isprintable())
self.assertFalse('\U000E0020'.isprintable())
# TODO: RUSTPYTHON
@unittest.expectedFailure
def test_surrogates(self):
for s in ('a\uD800b\uDFFF', 'a\uDFFFb\uD800',
'a\uD800b\uDFFFa', 'a\uDFFFb\uD800a'):
@@ -1827,8 +1821,6 @@ class UnicodeTest(string_tests.CommonTest,
'ill-formed sequence'):
b'+@'.decode('utf-7')
# TODO: RUSTPYTHON
@unittest.expectedFailure
def test_codecs_utf8(self):
self.assertEqual(''.encode('utf-8'), b'')
self.assertEqual('\u20ac'.encode('utf-8'), b'\xe2\x82\xac')

View File

@@ -53,8 +53,6 @@ class UserStringTest(
str3 = ustr3('TEST')
self.assertEqual(fmt2 % str3, 'value is TEST')
# TODO: RUSTPYTHON
@unittest.expectedFailure
def test_encode_default_args(self):
self.checkequal(b'hello', 'hello', 'encode')
# Check that encoding defaults to utf-8
@@ -62,8 +60,6 @@ class UserStringTest(
# Check that errors defaults to 'strict'
self.checkraises(UnicodeError, '\ud800', 'encode')
# TODO: RUSTPYTHON
@unittest.expectedFailure
def test_encode_explicit_none_args(self):
self.checkequal(b'hello', 'hello', 'encode', None, None)
# Check that encoding defaults to utf-8

View File

@@ -730,6 +730,7 @@ class UncompressedZipImportTestCase(ImportHooksBaseTestCase):
@unittest.skipIf(os_helper.TESTFN_UNENCODABLE is None,
"need an unencodable filename")
@unittest.expectedFailureIfWindows("TODO: RUSTPYTHON")
def testUnencodable(self):
filename = os_helper.TESTFN_UNENCODABLE + ".zip"
self.addCleanup(os_helper.unlink, filename)

View File

@@ -122,18 +122,18 @@ impl CodePoint {
/// Returns the numeric value of the code point if it is a leading surrogate.
#[inline]
pub fn to_lead_surrogate(self) -> Option<u16> {
pub fn to_lead_surrogate(self) -> Option<LeadSurrogate> {
match self.value {
lead @ 0xD800..=0xDBFF => Some(lead as u16),
lead @ 0xD800..=0xDBFF => Some(LeadSurrogate(lead as u16)),
_ => None,
}
}
/// Returns the numeric value of the code point if it is a trailing surrogate.
#[inline]
pub fn to_trail_surrogate(self) -> Option<u16> {
pub fn to_trail_surrogate(self) -> Option<TrailSurrogate> {
match self.value {
trail @ 0xDC00..=0xDFFF => Some(trail as u16),
trail @ 0xDC00..=0xDFFF => Some(TrailSurrogate(trail as u16)),
_ => None,
}
}
@@ -216,6 +216,18 @@ impl PartialEq<CodePoint> for char {
}
}
#[derive(Clone, Copy)]
pub struct LeadSurrogate(u16);
#[derive(Clone, Copy)]
pub struct TrailSurrogate(u16);
impl LeadSurrogate {
pub fn merge(self, trail: TrailSurrogate) -> char {
decode_surrogate_pair(self.0, trail.0)
}
}
/// An owned, growable string of well-formed WTF-8 data.
///
/// Similar to `String`, but can additionally contain surrogate code points
@@ -291,6 +303,14 @@ impl Wtf8Buf {
Wtf8Buf { bytes: value }
}
/// Create a WTF-8 string from a WTF-8 byte vec.
pub fn from_bytes(value: Vec<u8>) -> Result<Self, Vec<u8>> {
match Wtf8::from_bytes(&value) {
Some(_) => Ok(unsafe { Self::from_bytes_unchecked(value) }),
None => Err(value),
}
}
/// Creates a WTF-8 string from a UTF-8 `String`.
///
/// This takes ownership of the `String` and does not copy.
@@ -750,15 +770,10 @@ impl Wtf8 {
}
fn decode_surrogate(b: &[u8]) -> Option<CodePoint> {
let [a, b, c, ..] = *b else { return None };
if (a & 0xf0) == 0xe0 && (b & 0xc0) == 0x80 && (c & 0xc0) == 0x80 {
// it's a three-byte code
let c = ((a as u32 & 0x0f) << 12) + ((b as u32 & 0x3f) << 6) + (c as u32 & 0x3f);
let 0xD800..=0xDFFF = c else { return None };
Some(CodePoint { value: c })
} else {
None
}
let [0xed, b2 @ (0xa0..), b3, ..] = *b else {
return None;
};
Some(decode_surrogate(b2, b3).into())
}
/// Returns the length, in WTF-8 bytes.
@@ -914,14 +929,6 @@ impl Wtf8 {
}
}
#[inline]
fn final_lead_surrogate(&self) -> Option<u16> {
match self.bytes {
[.., 0xED, b2 @ 0xA0..=0xAF, b3] => Some(decode_surrogate(b2, b3)),
_ => None,
}
}
pub fn is_code_point_boundary(&self, index: usize) -> bool {
is_code_point_boundary(self, index)
}
@@ -1222,6 +1229,12 @@ fn decode_surrogate(second_byte: u8, third_byte: u8) -> u16 {
0xD800 | (second_byte as u16 & 0x3F) << 6 | third_byte as u16 & 0x3F
}
#[inline]
fn decode_surrogate_pair(lead: u16, trail: u16) -> char {
let code_point = 0x10000 + ((((lead - 0xD800) as u32) << 10) | (trail - 0xDC00) as u32);
unsafe { char::from_u32_unchecked(code_point) }
}
/// Copied from str::is_char_boundary
#[inline]
fn is_code_point_boundary(slice: &Wtf8, index: usize) -> bool {

View File

@@ -13,6 +13,7 @@ mod _json {
types::{Callable, Constructor},
};
use malachite_bigint::BigInt;
use rustpython_common::wtf8::Wtf8Buf;
use std::str::FromStr;
#[pyattr(name = "make_scanner")]
@@ -253,8 +254,8 @@ mod _json {
end: usize,
strict: OptionalArg<bool>,
vm: &VirtualMachine,
) -> PyResult<(String, usize)> {
machinery::scanstring(s.as_str(), end, strict.unwrap_or(true))
) -> PyResult<(Wtf8Buf, usize)> {
machinery::scanstring(s.as_wtf8(), end, strict.unwrap_or(true))
.map_err(|e| py_decode_error(e, s, vm))
}
}

View File

@@ -28,6 +28,9 @@
use std::io;
use itertools::Itertools;
use rustpython_common::wtf8::{CodePoint, Wtf8, Wtf8Buf};
static ESCAPE_CHARS: [&str; 0x20] = [
"\\u0000", "\\u0001", "\\u0002", "\\u0003", "\\u0004", "\\u0005", "\\u0006", "\\u0007", "\\b",
"\\t", "\\n", "\\u000", "\\f", "\\r", "\\u000e", "\\u000f", "\\u0010", "\\u0011", "\\u0012",
@@ -111,22 +114,22 @@ impl DecodeError {
}
enum StrOrChar<'a> {
Str(&'a str),
Char(char),
Str(&'a Wtf8),
Char(CodePoint),
}
impl StrOrChar<'_> {
fn len(&self) -> usize {
match self {
StrOrChar::Str(s) => s.len(),
StrOrChar::Char(c) => c.len_utf8(),
StrOrChar::Char(c) => c.len_wtf8(),
}
}
}
pub fn scanstring<'a>(
s: &'a str,
s: &'a Wtf8,
end: usize,
strict: bool,
) -> Result<(String, usize), DecodeError> {
) -> Result<(Wtf8Buf, usize), DecodeError> {
let mut chunks: Vec<StrOrChar<'a>> = Vec::new();
let mut output_len = 0usize;
let mut push_chunk = |chunk: StrOrChar<'a>| {
@@ -134,16 +137,16 @@ pub fn scanstring<'a>(
chunks.push(chunk);
};
let unterminated_err = || DecodeError::new("Unterminated string starting at", end - 1);
let mut chars = s.char_indices().enumerate().skip(end).peekable();
let mut chars = s.code_point_indices().enumerate().skip(end).peekable();
let &(_, (mut chunk_start, _)) = chars.peek().ok_or_else(unterminated_err)?;
while let Some((char_i, (byte_i, c))) = chars.next() {
match c {
match c.to_char_lossy() {
'"' => {
push_chunk(StrOrChar::Str(&s[chunk_start..byte_i]));
let mut out = String::with_capacity(output_len);
let mut out = Wtf8Buf::with_capacity(output_len);
for x in chunks {
match x {
StrOrChar::Str(s) => out.push_str(s),
StrOrChar::Str(s) => out.push_wtf8(s),
StrOrChar::Char(c) => out.push(c),
}
}
@@ -152,7 +155,7 @@ pub fn scanstring<'a>(
'\\' => {
push_chunk(StrOrChar::Str(&s[chunk_start..byte_i]));
let (_, (_, c)) = chars.next().ok_or_else(unterminated_err)?;
let esc = match c {
let esc = match c.to_char_lossy() {
'"' => "\"",
'\\' => "\\",
'/' => "/",
@@ -162,41 +165,33 @@ pub fn scanstring<'a>(
'r' => "\r",
't' => "\t",
'u' => {
let surrogate_err = || DecodeError::new("unpaired surrogate", char_i);
let mut uni = decode_unicode(&mut chars, char_i)?;
chunk_start = byte_i + 6;
if (0xd800..=0xdbff).contains(&uni) {
if let Some(lead) = uni.to_lead_surrogate() {
// uni is a surrogate -- try to find its pair
if let Some(&(pos2, (_, '\\'))) = chars.peek() {
// ok, the next char starts an escape
chars.next();
if let Some((_, (_, 'u'))) = chars.peek() {
// ok, it's a unicode escape
chars.next();
let uni2 = decode_unicode(&mut chars, pos2)?;
let mut chars2 = chars.clone();
if let Some(((pos2, _), (_, _))) = chars2
.next_tuple()
.filter(|((_, (_, c1)), (_, (_, c2)))| *c1 == '\\' && *c2 == 'u')
{
let uni2 = decode_unicode(&mut chars2, pos2)?;
if let Some(trail) = uni2.to_trail_surrogate() {
// ok, we found what we were looking for -- \uXXXX\uXXXX, both surrogates
uni = lead.merge(trail).into();
chunk_start = pos2 + 6;
if (0xdc00..=0xdfff).contains(&uni2) {
// ok, we found what we were looking for -- \uXXXX\uXXXX, both surrogates
uni = 0x10000 + (((uni - 0xd800) << 10) | (uni2 - 0xdc00));
} else {
// if we don't find a matching surrogate, error -- until str
// isn't utf8 internally, we can't parse surrogates
return Err(surrogate_err());
}
} else {
return Err(surrogate_err());
chars = chars2;
}
}
}
push_chunk(StrOrChar::Char(
std::char::from_u32(uni).ok_or_else(surrogate_err)?,
));
push_chunk(StrOrChar::Char(uni));
continue;
}
_ => return Err(DecodeError::new(format!("Invalid \\escape: {c:?}"), char_i)),
_ => {
return Err(DecodeError::new(format!("Invalid \\escape: {c:?}"), char_i));
}
};
chunk_start = byte_i + 2;
push_chunk(StrOrChar::Str(esc));
push_chunk(StrOrChar::Str(esc.as_ref()));
}
'\x00'..='\x1f' if strict => {
return Err(DecodeError::new(
@@ -211,16 +206,16 @@ pub fn scanstring<'a>(
}
#[inline]
fn decode_unicode<I>(it: &mut I, pos: usize) -> Result<u32, DecodeError>
fn decode_unicode<I>(it: &mut I, pos: usize) -> Result<CodePoint, DecodeError>
where
I: Iterator<Item = (usize, (usize, char))>,
I: Iterator<Item = (usize, (usize, CodePoint))>,
{
let err = || DecodeError::new("Invalid \\uXXXX escape", pos);
let mut uni = 0;
for x in (0..4).rev() {
let (_, (_, c)) = it.next().ok_or_else(err)?;
let d = c.to_digit(16).ok_or_else(err)?;
uni += d * 16u32.pow(x);
let d = c.to_char().and_then(|c| c.to_digit(16)).ok_or_else(err)? as u16;
uni += d * 16u16.pow(x);
}
Ok(uni)
Ok(uni.into())
}

View File

@@ -179,9 +179,12 @@ impl Constructor for PyComplex {
"complex() can't take second arg if first is a string".to_owned(),
));
}
let value = parse_str(s.as_str().trim()).ok_or_else(|| {
vm.new_value_error("complex() arg is a malformed string".to_owned())
})?;
let value = s
.to_str()
.and_then(|s| parse_str(s.trim()))
.ok_or_else(|| {
vm.new_value_error("complex() arg is a malformed string".to_owned())
})?;
return Self::from(value)
.into_ref_with_type(vm, cls)
.map(Into::into);

View File

@@ -161,7 +161,7 @@ impl Constructor for PyFloat {
fn float_from_string(val: PyObjectRef, vm: &VirtualMachine) -> PyResult<f64> {
let (bytearray, buffer, buffer_lock);
let b = if let Some(s) = val.payload_if_subclass::<PyStr>(vm) {
s.as_str().trim().as_bytes()
s.as_wtf8().trim().as_bytes()
} else if let Some(bytes) = val.payload_if_subclass::<PyBytes>(vm) {
bytes.as_bytes()
} else if let Some(buf) = val.payload_if_subclass::<PyByteArray>(vm) {

View File

@@ -847,7 +847,7 @@ fn try_int_radix(obj: &PyObject, base: u32, vm: &VirtualMachine) -> PyResult<Big
let opt = match_class!(match obj.to_owned() {
string @ PyStr => {
let s = string.as_str().trim();
let s = string.as_wtf8().trim();
bytes_to_int(s.as_bytes(), base)
}
bytes @ PyBytes => {

View File

@@ -424,6 +424,23 @@ impl PyStr {
self.data.as_str()
}
pub fn try_to_str(&self, vm: &VirtualMachine) -> PyResult<&str> {
self.to_str().ok_or_else(|| {
let start = self
.as_wtf8()
.code_points()
.position(|c| c.to_char().is_none())
.unwrap();
vm.new_unicode_encode_error_real(
identifier!(vm, utf_8).to_owned(),
vm.ctx.new_str(self.data.clone()),
start,
start + 1,
vm.ctx.new_str("surrogates not allowed"),
)
})
}
pub fn to_string_lossy(&self) -> Cow<'_, str> {
self.to_str()
.map(Cow::Borrowed)
@@ -850,9 +867,9 @@ impl PyStr {
/// If the string starts with the prefix string, return string[len(prefix):]
/// Otherwise, return a copy of the original string.
#[pymethod]
fn removeprefix(&self, pref: PyStrRef) -> String {
self.as_str()
.py_removeprefix(pref.as_str(), pref.byte_len(), |s, p| s.starts_with(p))
fn removeprefix(&self, pref: PyStrRef) -> Wtf8Buf {
self.as_wtf8()
.py_removeprefix(pref.as_wtf8(), pref.byte_len(), |s, p| s.starts_with(p))
.to_owned()
}
@@ -861,9 +878,9 @@ impl PyStr {
/// If the string ends with the suffix string, return string[:len(suffix)]
/// Otherwise, return a copy of the original string.
#[pymethod]
fn removesuffix(&self, suffix: PyStrRef) -> String {
self.as_str()
.py_removesuffix(suffix.as_str(), suffix.byte_len(), |s, p| s.ends_with(p))
fn removesuffix(&self, suffix: PyStrRef) -> Wtf8Buf {
self.as_wtf8()
.py_removesuffix(suffix.as_wtf8(), suffix.byte_len(), |s, p| s.ends_with(p))
.to_owned()
}
@@ -1294,7 +1311,8 @@ impl PyStr {
#[pymethod]
fn isidentifier(&self) -> bool {
let mut chars = self.as_str().chars();
let Some(s) = self.to_str() else { return false };
let mut chars = s.chars();
let is_identifier_start = chars.next().is_some_and(|c| c == '_' || is_xid_start(c));
// a string is not an identifier if it has whitespace or starts with a number
is_identifier_start && chars.all(is_xid_continue)

View File

@@ -884,7 +884,7 @@ impl Constructor for PyType {
attributes
.entry(identifier!(vm, __qualname__))
.or_insert_with(|| vm.ctx.new_str(name.as_str()).into());
.or_insert_with(|| name.clone().into());
if attributes.get(identifier!(vm, __eq__)).is_some()
&& attributes.get(identifier!(vm, __hash__)).is_none()

View File

@@ -77,7 +77,7 @@ impl PyObject {
))
})
} else if let Some(s) = self.payload::<PyStr>() {
try_convert(self, s.as_str().trim().as_bytes(), vm)
try_convert(self, s.as_wtf8().trim().as_bytes(), vm)
} else if let Some(bytes) = self.payload::<PyBytes>() {
try_convert(self, bytes, vm)
} else if let Some(bytearray) = self.payload::<PyByteArray>() {

View File

@@ -224,7 +224,7 @@ impl PyObject {
dict: Option<PyDictRef>,
vm: &VirtualMachine,
) -> PyResult<Option<PyObjectRef>> {
let name = name_str.as_str();
let name = name_str.as_wtf8();
let obj_cls = self.class();
let cls_attr_name = vm.ctx.interned_str(name_str);
let cls_attr = match cls_attr_name.and_then(|name| obj_cls.get_attr(name)) {

View File

@@ -26,7 +26,7 @@ mod _codecs {
fn lookup(encoding: PyStrRef, vm: &VirtualMachine) -> PyResult {
vm.state
.codec_registry
.lookup(encoding.as_str(), vm)
.lookup(encoding.try_to_str(vm)?, vm)
.map(|codec| codec.into_tuple().into())
}

View File

@@ -2245,7 +2245,7 @@ mod _io {
let newline = args.newline.unwrap_or_default();
let (encoder, decoder) =
Self::find_coder(&buffer, encoding.as_str(), &errors, newline, vm)?;
Self::find_coder(&buffer, encoding.try_to_str(vm)?, &errors, newline, vm)?;
*data = Some(TextIOData {
buffer,
@@ -2345,7 +2345,7 @@ mod _io {
if let Some(encoding) = args.encoding {
let (encoder, decoder) = Self::find_coder(
&data.buffer,
encoding.as_str(),
encoding.try_to_str(vm)?,
&data.errors,
data.newline,
vm,
@@ -3468,9 +3468,9 @@ mod _io {
// return the entire contents of the underlying
#[pymethod]
fn getvalue(&self, vm: &VirtualMachine) -> PyResult<String> {
fn getvalue(&self, vm: &VirtualMachine) -> PyResult<Wtf8Buf> {
let bytes = self.buffer(vm)?.getvalue();
String::from_utf8(bytes)
Wtf8Buf::from_bytes(bytes)
.map_err(|_| vm.new_value_error("Error Retrieving Value".to_owned()))
}
@@ -3491,10 +3491,10 @@ mod _io {
// If k is undefined || k == -1, then we read all bytes until the end of the file.
// This also increments the stream position by the value of k
#[pymethod]
fn read(&self, size: OptionalSize, vm: &VirtualMachine) -> PyResult<String> {
fn read(&self, size: OptionalSize, vm: &VirtualMachine) -> PyResult<Wtf8Buf> {
let data = self.buffer(vm)?.read(size.to_usize()).unwrap_or_default();
let value = String::from_utf8(data)
let value = Wtf8Buf::from_bytes(data)
.map_err(|_| vm.new_value_error("Error Retrieving Value".to_owned()))?;
Ok(value)
}
@@ -3505,11 +3505,11 @@ mod _io {
}
#[pymethod]
fn readline(&self, size: OptionalSize, vm: &VirtualMachine) -> PyResult<String> {
fn readline(&self, size: OptionalSize, vm: &VirtualMachine) -> PyResult<Wtf8Buf> {
// TODO size should correspond to the number of characters, at the moments its the number of
// bytes.
let input = self.buffer(vm)?.readline(size.to_usize(), vm)?;
String::from_utf8(input)
Wtf8Buf::from_bytes(input)
.map_err(|_| vm.new_value_error("Error Retrieving Value".to_owned()))
}

View File

@@ -327,8 +327,12 @@ mod decl {
* raises an error if unsupported format is supplied.
* If error happens, we set result as input arg.
*/
write!(&mut formatted_time, "{}", instant.format(format.as_str()))
.unwrap_or_else(|_| formatted_time = format.to_string());
write!(
&mut formatted_time,
"{}",
instant.format(format.try_to_str(vm)?)
)
.unwrap_or_else(|_| formatted_time = format.to_string());
Ok(vm.ctx.new_str(formatted_time).into())
}

View File

@@ -1,3 +1,5 @@
use rustpython_common::wtf8::Wtf8;
use crate::{
PyObjectRef, PyResult, VirtualMachine,
builtins::PyStr,
@@ -18,9 +20,9 @@ impl ToPyObject for std::convert::Infallible {
}
}
pub trait ToCString: AsRef<str> {
pub trait ToCString: AsRef<Wtf8> {
fn to_cstring(&self, vm: &VirtualMachine) -> PyResult<std::ffi::CString> {
std::ffi::CString::new(self.as_ref()).map_err(|err| err.to_pyexception(vm))
std::ffi::CString::new(self.as_ref().as_bytes()).map_err(|err| err.to_pyexception(vm))
}
fn ensure_no_nul(&self, vm: &VirtualMachine) -> PyResult<()> {
if self.as_ref().as_bytes().contains(&b'\0') {