diff --git a/vm/src/stdlib/re.rs b/vm/src/stdlib/re.rs
index 94ea2c27b..b44f5701d 100644
--- a/vm/src/stdlib/re.rs
+++ b/vm/src/stdlib/re.rs
@@ -14,21 +14,15 @@ use crate::obj::objstr::{PyString, PyStringRef};
 use crate::obj::objtype::PyClassRef;
 use crate::pyobject::{PyClassImpl, PyObjectRef, PyResult, PyValue, TryFromObject};
 use crate::vm::VirtualMachine;
-use num_traits::ToPrimitive;
+use num_traits::{Signed, ToPrimitive};
 
-// #[derive(Debug)]
 #[pyclass(name = "Pattern")]
+#[derive(Debug)]
 struct PyPattern {
     regex: Regex,
     pattern: String,
 }
 
-impl fmt::Debug for PyPattern {
-    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
-        write!(f, "Pattern()")
-    }
-}
-
 const IGNORECASE: usize = 2;
 const LOCALE: usize = 4;
 const MULTILINE: usize = 8;
@@ -143,6 +137,18 @@ fn re_findall(
     do_findall(vm, &regex, string)
 }
 
+fn re_split(
+    pattern: PyStringRef,
+    string: PyStringRef,
+    maxsplit: OptionalArg<PyIntRef>,
+    flags: OptionalArg,
+    vm: &VirtualMachine,
+) -> PyResult {
+    let flags = extract_flags(flags);
+    let regex = make_regex(vm, pattern.as_str(), flags)?;
+    do_split(vm, &regex, string, maxsplit.into_option())
+}
+
 fn do_sub(
     vm: &VirtualMachine,
     pattern: &PyPattern,
@@ -150,15 +156,12 @@ fn do_sub(
     search_text: PyStringRef,
     limit: usize,
 ) -> PyResult {
-    let out = pattern
-        .regex
-        .replacen(
-            search_text.as_str().as_bytes(),
-            limit,
-            repl.as_str().as_bytes(),
-        )
-        .into_owned();
-    let out = unsafe { String::from_utf8_unchecked(out) };
+    let out = pattern.regex.replacen(
+        search_text.as_str().as_bytes(),
+        limit,
+        repl.as_str().as_bytes(),
+    );
+    let out = String::from_utf8_lossy(&out).into_owned();
     Ok(vm.new_str(out))
 }
 
@@ -208,6 +211,53 @@ fn do_findall(vm: &VirtualMachine, pattern: &PyPattern, search_text: PyStringRef
     Ok(vm.ctx.new_list(out))
 }
 
+fn do_split(
+    vm: &VirtualMachine,
+    pattern: &PyPattern,
+    search_text: PyStringRef,
+    maxsplit: Option<PyIntRef>,
+) -> PyResult {
+    if maxsplit
+        .as_ref()
+        .map_or(false, |i| i.as_bigint().is_negative())
+    {
+        return Ok(vm.ctx.new_list(vec![search_text.into_object()]));
+    }
+    let maxsplit = maxsplit
+        .map(|i| usize::try_from_object(vm, i.into_object()))
+        .transpose()?
+        .unwrap_or(0);
+    let text = search_text.as_str().as_bytes();
+    // essentially Regex::split, but it outputs captures as well
+    let mut output = Vec::new();
+    let mut last = 0;
+    let mut n = 0;
+    for captures in pattern.regex.captures_iter(text) {
+        let full = captures.get(0).unwrap();
+        let matched = &text[last..full.start()];
+        last = full.end();
+        output.push(Some(matched));
+        for m in captures.iter().skip(1) {
+            output.push(m.map(|m| m.as_bytes()));
+        }
+        n += 1;
+        if maxsplit != 0 && n >= maxsplit {
+            break;
+        }
+    }
+    if last < text.len() {
+        output.push(Some(&text[last..]));
+    }
+    let split = output
+        .into_iter()
+        .map(|v| {
+            v.map(|v| vm.new_str(String::from_utf8_lossy(v).into_owned()))
+                .unwrap_or_else(|| vm.get_none())
+        })
+        .collect();
+    Ok(vm.ctx.new_list(split))
+}
+
 fn make_regex(vm: &VirtualMachine, pattern: &str, flags: PyRegexFlags) -> PyResult<PyPattern> {
     let unicode = if flags.unicode && flags.ascii {
         return Err(vm.new_value_error("ASCII and UNICODE flags are incompatible".to_string()));
@@ -280,11 +330,8 @@ impl PyPattern {
     fn sub(&self, repl: PyStringRef, text: PyStringRef, vm: &VirtualMachine) -> PyResult {
         let replaced_text = self
             .regex
-            .replace_all(text.value.as_bytes(), repl.as_str().as_bytes())
-            .into_owned();
-        // safe because both the search and replace arguments ^ are unicode strings temporarily
-        // converted to bytes
-        let replaced_text = unsafe { String::from_utf8_unchecked(replaced_text) };
+            .replace_all(text.value.as_bytes(), repl.as_str().as_bytes());
+        let replaced_text = String::from_utf8_lossy(&replaced_text).into_owned();
         Ok(vm.ctx.new_str(replaced_text))
     }
 
@@ -299,13 +346,13 @@ impl PyPattern {
     }
 
     #[pymethod]
-    fn split(&self, text: PyStringRef, vm: &VirtualMachine) -> PyObjectRef {
-        let split = self
-            .regex
-            .split(text.as_str().as_bytes())
-            .map(|v| vm.new_str(String::from_utf8_lossy(v).into_owned()))
-            .collect();
-        vm.ctx.new_list(split)
+    fn split(
+        &self,
+        search_text: PyStringRef,
+        maxsplit: OptionalArg<PyIntRef>,
+        vm: &VirtualMachine,
+    ) -> PyResult {
+        do_split(vm, self, search_text, maxsplit.into_option())
     }
 }
 
@@ -407,6 +454,7 @@ pub fn make_module(vm: &VirtualMachine) -> PyObjectRef {
         "search" => ctx.new_rustfunc(re_search),
         "sub" => ctx.new_rustfunc(re_sub),
         "findall" => ctx.new_rustfunc(re_findall),
+        "split" => ctx.new_rustfunc(re_split),
         "IGNORECASE" => ctx.new_int(IGNORECASE),
         "I" => ctx.new_int(IGNORECASE),
         "LOCALE" => ctx.new_int(LOCALE),