diff --git a/vm/src/stdlib/re.rs b/vm/src/stdlib/re.rs index ca8eddce7..209246e13 100644 --- a/vm/src/stdlib/re.rs +++ b/vm/src/stdlib/re.rs @@ -12,7 +12,7 @@ use crate::function::{Args, OptionalArg}; use crate::obj::objint::{PyInt, PyIntRef}; use crate::obj::objstr::{PyString, PyStringRef}; use crate::obj::objtype::PyClassRef; -use crate::pyobject::{PyClassImpl, PyObjectRef, PyResult, PyValue}; +use crate::pyobject::{PyClassImpl, PyObjectRef, PyResult, PyValue, TryFromObject}; use crate::vm::VirtualMachine; use num_traits::ToPrimitive; @@ -118,6 +118,54 @@ fn re_search( do_search(vm, ®ex, string) } +fn re_sub( + pattern: PyStringRef, + repl: PyStringRef, + string: PyStringRef, + count: OptionalArg, + flags: OptionalArg, + vm: &VirtualMachine, +) -> PyResult { + let flags = extract_flags(flags); + let regex = make_regex(vm, pattern.as_str(), flags)?; + let limit = count + .into_option() + .map(|i| usize::try_from_object(vm, i.into_object())) + .transpose()? + .unwrap_or(0); + do_sub(vm, ®ex, repl, string, limit) +} + +fn re_findall( + pattern: PyStringRef, + string: PyStringRef, + flags: OptionalArg, + vm: &VirtualMachine, +) -> PyResult { + let flags = extract_flags(flags); + let regex = make_regex(vm, pattern.as_str(), flags)?; + do_findall(vm, ®ex, string) +} + +fn do_sub( + vm: &VirtualMachine, + pattern: &PyPattern, + repl: PyStringRef, + search_text: PyStringRef, + limit: usize, +) -> PyResult { + let out = pattern + .regex + .replacen( + search_text.as_str().as_bytes(), + limit, + repl.as_str().as_bytes(), + ) + .into_owned(); + let out = unsafe { String::from_utf8_unchecked(out) }; + Ok(vm.new_str(out)) +} + fn do_match(vm: &VirtualMachine, pattern: &PyPattern, search_text: PyStringRef) -> PyResult { // I really wish there was a better way to do this; I don't think there is let mut regex = r"\A".to_owned(); @@ -137,6 +185,33 @@ fn do_search(vm: &VirtualMachine, regex: &PyPattern, search_text: PyStringRef) - } } +fn do_findall(vm: &VirtualMachine, pattern: &PyPattern, search_text: PyStringRef) -> PyResult { + let out = pattern + .regex + .captures_iter(search_text.as_str().as_bytes()) + .map(|captures| { + if captures.len() == 1 { + let full = captures.get(1).unwrap().as_bytes(); + let full = String::from_utf8_lossy(full).into_owned(); + vm.new_str(full) + } else { + let out = captures + .iter() + .skip(1) + .map(|m| { + let s = m + .map(|m| String::from_utf8_lossy(m.as_bytes()).into_owned()) + .unwrap_or_default(); + vm.ctx.new_str(s) + }) + .collect(); + vm.ctx.new_tuple(out) + } + }) + .collect(); + Ok(vm.ctx.new_list(out)) +} + fn make_regex(vm: &VirtualMachine, pattern: &str, flags: PyRegexFlags) -> PyResult { let unicode = if flags.unicode && flags.ascii { return Err(vm.new_value_error("ASCII and UNICODE flags are incompatible".to_string())); @@ -150,7 +225,10 @@ fn make_regex(vm: &VirtualMachine, pattern: &str, flags: PyRegexFlags) -> PyResu .ignore_whitespace(flags.verbose) .unicode(unicode) .build() - .map_err(|err| vm.new_value_error(format!("Error in regex: {:?}", err)))?; + .map_err(|err| match err { + regex::Error::Syntax(s) => vm.new_value_error(format!("Error in regex: {}", s)), + err => vm.new_value_error(format!("Error in regex: {:?}", err)), + })?; Ok(PyPattern { regex: r, pattern: pattern.to_string(), @@ -252,9 +330,7 @@ impl PyMatch { fn get_bounds(&self, id: PyObjectRef, vm: &VirtualMachine) -> PyResult> { match_class!(id, i @ PyInt => { - let i = i.as_bigint().to_usize().ok_or_else(|| { - vm.new_overflow_error("Cannot fit index into rust usize".to_owned()) - })?; + let i = usize::try_from_object(vm,i.into_object())?; match self.captures.get(i) { None => Err(vm.new_index_error("No such group".to_owned())), Some(None) => Ok(None), @@ -323,6 +399,8 @@ pub fn make_module(vm: &VirtualMachine) -> PyObjectRef { "match" => ctx.new_rustfunc(re_match), "Pattern" => pattern_type, "search" => ctx.new_rustfunc(re_search), + "sub" => ctx.new_rustfunc(re_sub), + "findall" => ctx.new_rustfunc(re_findall), "IGNORECASE" => ctx.new_int(IGNORECASE), "I" => ctx.new_int(IGNORECASE), "LOCALE" => ctx.new_int(LOCALE),