From b36b32bfe810ab96cf945ff6ec04f85fcf46dd03 Mon Sep 17 00:00:00 2001 From: Noa Date: Fri, 21 Mar 2025 20:13:15 -0500 Subject: [PATCH] Make re wtf8-compatible --- Cargo.lock | 1 + Lib/test/test_re.py | 2 -- Lib/test/test_smtplib.py | 2 -- vm/src/stdlib/sre.rs | 16 ++++++--- vm/sre_engine/Cargo.toml | 5 +++ vm/sre_engine/src/string.rs | 70 +++++++++++++++++++++++++++++++++++++ 6 files changed, 87 insertions(+), 9 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 94b0732b8..1fa60eac2 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2474,6 +2474,7 @@ dependencies = [ "criterion", "num_enum", "optional", + "rustpython-common", ] [[package]] diff --git a/Lib/test/test_re.py b/Lib/test/test_re.py index a060a3dee..f6af44df9 100644 --- a/Lib/test/test_re.py +++ b/Lib/test/test_re.py @@ -854,8 +854,6 @@ class ReTests(unittest.TestCase): # Can match around the whitespace. self.assertEqual(len(re.findall(r"\B", " ")), 2) - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_bigcharset(self): self.assertEqual(re.match("([\u2222\u2223])", "\u2222").group(1), "\u2222") diff --git a/Lib/test/test_smtplib.py b/Lib/test/test_smtplib.py index a36d7bbe2..9b787950f 100644 --- a/Lib/test/test_smtplib.py +++ b/Lib/test/test_smtplib.py @@ -1459,8 +1459,6 @@ class SMTPUTF8SimTests(unittest.TestCase): self.assertIn('SMTPUTF8', self.serv.last_mail_options) self.assertEqual(self.serv.last_rcpt_options, []) - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_send_message_uses_smtputf8_if_addrs_non_ascii(self): msg = EmailMessage() msg['From'] = "Páolo " diff --git a/vm/src/stdlib/sre.rs b/vm/src/stdlib/sre.rs index 193976a62..038ac9934 100644 --- a/vm/src/stdlib/sre.rs +++ b/vm/src/stdlib/sre.rs @@ -9,6 +9,7 @@ mod _sre { PyCallableIterator, PyDictRef, PyGenericAlias, PyInt, PyList, PyListRef, PyStr, PyStrRef, PyTuple, PyTupleRef, PyTypeRef, }, + common::wtf8::{Wtf8, Wtf8Buf}, common::{ascii, hash::PyHash}, convert::ToPyObject, function::{ArgCallable, OptionalArg, PosArgs, PyComparisonValue}, @@ -66,10 +67,15 @@ mod _sre { } } - impl SreStr for &str { + impl SreStr for &Wtf8 { fn slice(&self, start: usize, end: usize, vm: &VirtualMachine) -> PyObjectRef { vm.ctx - .new_str(self.chars().take(end).skip(start).collect::()) + .new_str( + self.code_points() + .take(end) + .skip(start) + .collect::(), + ) .into() } } @@ -206,12 +212,12 @@ mod _sre { impl Pattern { fn with_str(string: &PyObject, vm: &VirtualMachine, f: F) -> PyResult where - F: FnOnce(&str) -> PyResult, + F: FnOnce(&Wtf8) -> PyResult, { let string = string.payload::().ok_or_else(|| { vm.new_type_error(format!("expected string got '{}'", string.class())) })?; - f(string.as_str()) + f(string.as_wtf8()) } fn with_bytes(string: &PyObject, vm: &VirtualMachine, f: F) -> PyResult @@ -425,7 +431,7 @@ mod _sre { let is_template = if zelf.isbytes { Self::with_bytes(&repl, vm, |x| Ok(x.contains(&b'\\')))? } else { - Self::with_str(&repl, vm, |x| Ok(x.contains('\\')))? + Self::with_str(&repl, vm, |x| Ok(x.contains("\\".as_ref())))? }; if is_template { diff --git a/vm/sre_engine/Cargo.toml b/vm/sre_engine/Cargo.toml index 504652f3a..55ce24990 100644 --- a/vm/sre_engine/Cargo.toml +++ b/vm/sre_engine/Cargo.toml @@ -14,7 +14,12 @@ license.workspace = true name = "benches" harness = false +[features] +default = ["wtf8"] +wtf8 = ["rustpython-common"] + [dependencies] +rustpython-common = { workspace = true, optional = true } num_enum = { workspace = true } bitflags = { workspace = true } optional = "0.5" diff --git a/vm/sre_engine/src/string.rs b/vm/sre_engine/src/string.rs index 77e0f3e77..551b1ca5e 100644 --- a/vm/sre_engine/src/string.rs +++ b/vm/sre_engine/src/string.rs @@ -1,3 +1,6 @@ +#[cfg(feature = "wtf8")] +use rustpython_common::wtf8::Wtf8; + #[derive(Debug, Clone, Copy)] pub struct StringCursor { pub(crate) ptr: *const u8, @@ -148,6 +151,73 @@ impl StrDrive for &str { } } +#[cfg(feature = "wtf8")] +impl StrDrive for &Wtf8 { + #[inline] + fn count(&self) -> usize { + self.code_points().count() + } + + #[inline] + fn create_cursor(&self, n: usize) -> StringCursor { + let mut cursor = StringCursor { + ptr: self.as_bytes().as_ptr(), + position: 0, + }; + Self::skip(&mut cursor, n); + cursor + } + + #[inline] + fn adjust_cursor(&self, cursor: &mut StringCursor, n: usize) { + if cursor.ptr.is_null() || cursor.position > n { + *cursor = Self::create_cursor(self, n); + } else if cursor.position < n { + Self::skip(cursor, n - cursor.position); + } + } + + #[inline] + fn advance(cursor: &mut StringCursor) -> u32 { + cursor.position += 1; + unsafe { next_code_point(&mut cursor.ptr) } + } + + #[inline] + fn peek(cursor: &StringCursor) -> u32 { + let mut ptr = cursor.ptr; + unsafe { next_code_point(&mut ptr) } + } + + #[inline] + fn skip(cursor: &mut StringCursor, n: usize) { + cursor.position += n; + for _ in 0..n { + unsafe { next_code_point(&mut cursor.ptr) }; + } + } + + #[inline] + fn back_advance(cursor: &mut StringCursor) -> u32 { + cursor.position -= 1; + unsafe { next_code_point_reverse(&mut cursor.ptr) } + } + + #[inline] + fn back_peek(cursor: &StringCursor) -> u32 { + let mut ptr = cursor.ptr; + unsafe { next_code_point_reverse(&mut ptr) } + } + + #[inline] + fn back_skip(cursor: &mut StringCursor, n: usize) { + cursor.position -= n; + for _ in 0..n { + unsafe { next_code_point_reverse(&mut cursor.ptr) }; + } + } +} + /// Reads the next code point out of a byte iterator (assuming a /// UTF-8-like encoding). ///