Make re wtf8-compatible

This commit is contained in:
Noa
2025-03-21 20:13:15 -05:00
parent 3945d3b2fe
commit b36b32bfe8
6 changed files with 87 additions and 9 deletions

1
Cargo.lock generated
View File

@@ -2474,6 +2474,7 @@ dependencies = [
"criterion",
"num_enum",
"optional",
"rustpython-common",
]
[[package]]

2
Lib/test/test_re.py vendored
View File

@@ -854,8 +854,6 @@ class ReTests(unittest.TestCase):
# Can match around the whitespace.
self.assertEqual(len(re.findall(r"\B", " ")), 2)
# TODO: RUSTPYTHON
@unittest.expectedFailure
def test_bigcharset(self):
self.assertEqual(re.match("([\u2222\u2223])",
"\u2222").group(1), "\u2222")

View File

@@ -1459,8 +1459,6 @@ class SMTPUTF8SimTests(unittest.TestCase):
self.assertIn('SMTPUTF8', self.serv.last_mail_options)
self.assertEqual(self.serv.last_rcpt_options, [])
# TODO: RUSTPYTHON
@unittest.expectedFailure
def test_send_message_uses_smtputf8_if_addrs_non_ascii(self):
msg = EmailMessage()
msg['From'] = "Páolo <főo@bar.com>"

View File

@@ -9,6 +9,7 @@ mod _sre {
PyCallableIterator, PyDictRef, PyGenericAlias, PyInt, PyList, PyListRef, PyStr,
PyStrRef, PyTuple, PyTupleRef, PyTypeRef,
},
common::wtf8::{Wtf8, Wtf8Buf},
common::{ascii, hash::PyHash},
convert::ToPyObject,
function::{ArgCallable, OptionalArg, PosArgs, PyComparisonValue},
@@ -66,10 +67,15 @@ mod _sre {
}
}
impl SreStr for &str {
impl SreStr for &Wtf8 {
fn slice(&self, start: usize, end: usize, vm: &VirtualMachine) -> PyObjectRef {
vm.ctx
.new_str(self.chars().take(end).skip(start).collect::<String>())
.new_str(
self.code_points()
.take(end)
.skip(start)
.collect::<Wtf8Buf>(),
)
.into()
}
}
@@ -206,12 +212,12 @@ mod _sre {
impl Pattern {
fn with_str<F, R>(string: &PyObject, vm: &VirtualMachine, f: F) -> PyResult<R>
where
F: FnOnce(&str) -> PyResult<R>,
F: FnOnce(&Wtf8) -> PyResult<R>,
{
let string = string.payload::<PyStr>().ok_or_else(|| {
vm.new_type_error(format!("expected string got '{}'", string.class()))
})?;
f(string.as_str())
f(string.as_wtf8())
}
fn with_bytes<F, R>(string: &PyObject, vm: &VirtualMachine, f: F) -> PyResult<R>
@@ -425,7 +431,7 @@ mod _sre {
let is_template = if zelf.isbytes {
Self::with_bytes(&repl, vm, |x| Ok(x.contains(&b'\\')))?
} else {
Self::with_str(&repl, vm, |x| Ok(x.contains('\\')))?
Self::with_str(&repl, vm, |x| Ok(x.contains("\\".as_ref())))?
};
if is_template {

View File

@@ -14,7 +14,12 @@ license.workspace = true
name = "benches"
harness = false
[features]
default = ["wtf8"]
wtf8 = ["rustpython-common"]
[dependencies]
rustpython-common = { workspace = true, optional = true }
num_enum = { workspace = true }
bitflags = { workspace = true }
optional = "0.5"

View File

@@ -1,3 +1,6 @@
#[cfg(feature = "wtf8")]
use rustpython_common::wtf8::Wtf8;
#[derive(Debug, Clone, Copy)]
pub struct StringCursor {
pub(crate) ptr: *const u8,
@@ -148,6 +151,73 @@ impl StrDrive for &str {
}
}
#[cfg(feature = "wtf8")]
impl StrDrive for &Wtf8 {
#[inline]
fn count(&self) -> usize {
self.code_points().count()
}
#[inline]
fn create_cursor(&self, n: usize) -> StringCursor {
let mut cursor = StringCursor {
ptr: self.as_bytes().as_ptr(),
position: 0,
};
Self::skip(&mut cursor, n);
cursor
}
#[inline]
fn adjust_cursor(&self, cursor: &mut StringCursor, n: usize) {
if cursor.ptr.is_null() || cursor.position > n {
*cursor = Self::create_cursor(self, n);
} else if cursor.position < n {
Self::skip(cursor, n - cursor.position);
}
}
#[inline]
fn advance(cursor: &mut StringCursor) -> u32 {
cursor.position += 1;
unsafe { next_code_point(&mut cursor.ptr) }
}
#[inline]
fn peek(cursor: &StringCursor) -> u32 {
let mut ptr = cursor.ptr;
unsafe { next_code_point(&mut ptr) }
}
#[inline]
fn skip(cursor: &mut StringCursor, n: usize) {
cursor.position += n;
for _ in 0..n {
unsafe { next_code_point(&mut cursor.ptr) };
}
}
#[inline]
fn back_advance(cursor: &mut StringCursor) -> u32 {
cursor.position -= 1;
unsafe { next_code_point_reverse(&mut cursor.ptr) }
}
#[inline]
fn back_peek(cursor: &StringCursor) -> u32 {
let mut ptr = cursor.ptr;
unsafe { next_code_point_reverse(&mut ptr) }
}
#[inline]
fn back_skip(cursor: &mut StringCursor, n: usize) {
cursor.position -= n;
for _ in 0..n {
unsafe { next_code_point_reverse(&mut cursor.ptr) };
}
}
}
/// Reads the next code point out of a byte iterator (assuming a
/// UTF-8-like encoding).
///