mirror of
https://github.com/RustPython/RustPython.git
synced 2026-06-09 22:49:57 +09:00
Fix re.split
This commit is contained in:
@@ -14,21 +14,15 @@ use crate::obj::objstr::{PyString, PyStringRef};
|
||||
use crate::obj::objtype::PyClassRef;
|
||||
use crate::pyobject::{PyClassImpl, PyObjectRef, PyResult, PyValue, TryFromObject};
|
||||
use crate::vm::VirtualMachine;
|
||||
use num_traits::ToPrimitive;
|
||||
use num_traits::{Signed, ToPrimitive};
|
||||
|
||||
// #[derive(Debug)]
|
||||
#[pyclass(name = "Pattern")]
|
||||
#[derive(Debug)]
|
||||
struct PyPattern {
|
||||
regex: Regex,
|
||||
pattern: String,
|
||||
}
|
||||
|
||||
impl fmt::Debug for PyPattern {
|
||||
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||
write!(f, "Pattern()")
|
||||
}
|
||||
}
|
||||
|
||||
const IGNORECASE: usize = 2;
|
||||
const LOCALE: usize = 4;
|
||||
const MULTILINE: usize = 8;
|
||||
@@ -143,6 +137,18 @@ fn re_findall(
|
||||
do_findall(vm, ®ex, string)
|
||||
}
|
||||
|
||||
fn re_split(
|
||||
pattern: PyStringRef,
|
||||
string: PyStringRef,
|
||||
maxsplit: OptionalArg<PyIntRef>,
|
||||
flags: OptionalArg<PyIntRef>,
|
||||
vm: &VirtualMachine,
|
||||
) -> PyResult {
|
||||
let flags = extract_flags(flags);
|
||||
let regex = make_regex(vm, pattern.as_str(), flags)?;
|
||||
do_split(vm, ®ex, string, maxsplit.into_option())
|
||||
}
|
||||
|
||||
fn do_sub(
|
||||
vm: &VirtualMachine,
|
||||
pattern: &PyPattern,
|
||||
@@ -150,15 +156,12 @@ fn do_sub(
|
||||
search_text: PyStringRef,
|
||||
limit: usize,
|
||||
) -> PyResult {
|
||||
let out = pattern
|
||||
.regex
|
||||
.replacen(
|
||||
search_text.as_str().as_bytes(),
|
||||
limit,
|
||||
repl.as_str().as_bytes(),
|
||||
)
|
||||
.into_owned();
|
||||
let out = unsafe { String::from_utf8_unchecked(out) };
|
||||
let out = pattern.regex.replacen(
|
||||
search_text.as_str().as_bytes(),
|
||||
limit,
|
||||
repl.as_str().as_bytes(),
|
||||
);
|
||||
let out = String::from_utf8_lossy(&out).into_owned();
|
||||
Ok(vm.new_str(out))
|
||||
}
|
||||
|
||||
@@ -208,6 +211,53 @@ fn do_findall(vm: &VirtualMachine, pattern: &PyPattern, search_text: PyStringRef
|
||||
Ok(vm.ctx.new_list(out))
|
||||
}
|
||||
|
||||
fn do_split(
|
||||
vm: &VirtualMachine,
|
||||
pattern: &PyPattern,
|
||||
search_text: PyStringRef,
|
||||
maxsplit: Option<PyIntRef>,
|
||||
) -> PyResult {
|
||||
if maxsplit
|
||||
.as_ref()
|
||||
.map_or(false, |i| i.as_bigint().is_negative())
|
||||
{
|
||||
return Ok(vm.ctx.new_list(vec![search_text.into_object()]));
|
||||
}
|
||||
let maxsplit = maxsplit
|
||||
.map(|i| usize::try_from_object(vm, i.into_object()))
|
||||
.transpose()?
|
||||
.unwrap_or(0);
|
||||
let text = search_text.as_str().as_bytes();
|
||||
// essentially Regex::split, but it outputs captures as well
|
||||
let mut output = Vec::new();
|
||||
let mut last = 0;
|
||||
let mut n = 0;
|
||||
for captures in pattern.regex.captures_iter(text) {
|
||||
let full = captures.get(0).unwrap();
|
||||
let matched = &text[last..full.start()];
|
||||
last = full.end();
|
||||
output.push(Some(matched));
|
||||
for m in captures.iter().skip(1) {
|
||||
output.push(m.map(|m| m.as_bytes()));
|
||||
}
|
||||
n += 1;
|
||||
if maxsplit != 0 && n >= maxsplit {
|
||||
break;
|
||||
}
|
||||
}
|
||||
if last < text.len() {
|
||||
output.push(Some(&text[last..]));
|
||||
}
|
||||
let split = output
|
||||
.into_iter()
|
||||
.map(|v| {
|
||||
v.map(|v| vm.new_str(String::from_utf8_lossy(v).into_owned()))
|
||||
.unwrap_or_else(|| vm.get_none())
|
||||
})
|
||||
.collect();
|
||||
Ok(vm.ctx.new_list(split))
|
||||
}
|
||||
|
||||
fn make_regex(vm: &VirtualMachine, pattern: &str, flags: PyRegexFlags) -> PyResult<PyPattern> {
|
||||
let unicode = if flags.unicode && flags.ascii {
|
||||
return Err(vm.new_value_error("ASCII and UNICODE flags are incompatible".to_string()));
|
||||
@@ -280,11 +330,8 @@ impl PyPattern {
|
||||
fn sub(&self, repl: PyStringRef, text: PyStringRef, vm: &VirtualMachine) -> PyResult {
|
||||
let replaced_text = self
|
||||
.regex
|
||||
.replace_all(text.value.as_bytes(), repl.as_str().as_bytes())
|
||||
.into_owned();
|
||||
// safe because both the search and replace arguments ^ are unicode strings temporarily
|
||||
// converted to bytes
|
||||
let replaced_text = unsafe { String::from_utf8_unchecked(replaced_text) };
|
||||
.replace_all(text.value.as_bytes(), repl.as_str().as_bytes());
|
||||
let replaced_text = String::from_utf8_lossy(&replaced_text).into_owned();
|
||||
Ok(vm.ctx.new_str(replaced_text))
|
||||
}
|
||||
|
||||
@@ -299,13 +346,13 @@ impl PyPattern {
|
||||
}
|
||||
|
||||
#[pymethod]
|
||||
fn split(&self, text: PyStringRef, vm: &VirtualMachine) -> PyObjectRef {
|
||||
let split = self
|
||||
.regex
|
||||
.split(text.as_str().as_bytes())
|
||||
.map(|v| vm.new_str(String::from_utf8_lossy(v).into_owned()))
|
||||
.collect();
|
||||
vm.ctx.new_list(split)
|
||||
fn split(
|
||||
&self,
|
||||
search_text: PyStringRef,
|
||||
maxsplit: OptionalArg<PyIntRef>,
|
||||
vm: &VirtualMachine,
|
||||
) -> PyResult {
|
||||
do_split(vm, self, search_text, maxsplit.into_option())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -407,6 +454,7 @@ pub fn make_module(vm: &VirtualMachine) -> PyObjectRef {
|
||||
"search" => ctx.new_rustfunc(re_search),
|
||||
"sub" => ctx.new_rustfunc(re_sub),
|
||||
"findall" => ctx.new_rustfunc(re_findall),
|
||||
"split" => ctx.new_rustfunc(re_split),
|
||||
"IGNORECASE" => ctx.new_int(IGNORECASE),
|
||||
"I" => ctx.new_int(IGNORECASE),
|
||||
"LOCALE" => ctx.new_int(LOCALE),
|
||||
|
||||
Reference in New Issue
Block a user